In [1]:
import numpy as np
import glob
from osgeo import gdal, osr
import pyproj
from shutil import copyfile

import PIL
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split

from tqdm.notebook import tqdm

import matplotlib.pyplot as plt

PIL.Image.MAX_IMAGE_PIXELS = 933120000

In [2]:
class SiameseDataset(Dataset):
    def __init__(self, base_raster, image_paths,):
        self.toTensor = transforms.Compose([
            transforms.PILToTensor(),
            transforms.Resize((2000, 3000))
        ])
        self.base_raster = base_raster
        self.base = self.toTensor(Image.open(self.base_raster))[0, :, :].float()
        self.image_paths = image_paths
        a = list()
        for path in self.image_paths:
            a.append(np.load(path[:-4]+"_affine.npy"))
        self.affine = np.vstack(a)

    def __getitem__(self, index):
        img1_path = self.image_paths[index]
        img1 = self.toTensor(Image.open(img1_path))
        img1 = torch.where(img1, 1, 255).float()
        affine = self.affine[index, :]
        # label = torch.Tensor([int(self.image_paths[index][2])])
        return self.base, img1, affine

    def __len__(self):
        return len(self.image_paths)

In [3]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2),
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        return x


In [4]:
import torch
import torch.nn as nn

class SiameseNet(nn.Module):
    def __init__(self):
        super(SiameseNet, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=11, stride=5, padding=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(8, 16, kernel_size=5, stride=3, padding=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=4),
        )

        self.fc = nn.Sequential(
            nn.Linear(256256, 65536),
            nn.ReLU(inplace=True),
            nn.Linear(65536, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 6),
        )

    def forward(self, x1, x2):
        x1 = torch.squeeze(self.conv(x1))
        x2 = torch.squeeze(self.conv(x2))
        x = torch.cat((x1, x2), dim=1).flatten()
        x = self.fc(x)
        return x.squeeze()

In [5]:
def train_siamese_net(net,
                      siamese_dataset, 
                      batch_size=1, 
                      num_epochs=1000, 
                      learning_rate=0.001, 
                      validation_split=0.2, 
                      device='cuda'):
    # Split dataset into training and validation sets
    dataset_size = len(siamese_dataset)
    val_size = int(dataset_size * validation_split)
    train_size = dataset_size - val_size
    train_dataset, val_dataset = random_split(siamese_dataset, [train_size, val_size])

    # Create data loaders for training and validation sets
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    
    # Define loss function and optimizer
    # criterion = nn.BCEWithLogitsLoss()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    
    bestval = torch.tensor(float('inf'))
    
    # Training loop
    for epoch in range(num_epochs):
        # Train the network
        net.train()
        train_loss = 0.0
        for batch_idx, (img1, img2, label) in tqdm(enumerate(train_loader), total=len(train_loader)):
            img1, img2, label = img1.to(device), img2.to(device), label.to(device)
            optimizer.zero_grad()
            output = net(img1, img2)
            loss = criterion(output.float(), label.squeeze().float())
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        
        # Evaluate the network on the validation set
        net.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch_idx, (img1, img2, label) in enumerate(val_loader):
                img1, img2, label = img1.to(device), img2.to(device), label.to(device)
                output = net(img1, img2)
                loss = criterion(output.float(), label.squeeze().float())
                val_loss += loss.item()
                
            if bestval > val_loss:
                bestval = val_loss
                checkpoint = {
                    'epoch': epoch, 
                    'model_state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()
                }
                torch.save(checkpoint, "checkpoint.pth")

        # Print training and validation loss for the epoch
        print('Epoch [{}/{}], Train Loss: {:.4f}, Val Loss: {:.4f}'.format(epoch+1, num_epochs, train_loss/len(train_loader), val_loss/len(val_loader)))

    # Get the final affine transformation matrix
    net.eval()
    with torch.no_grad():
        img1, img2, _ = siamese_dataset[0]
        img1, img2 = img1.to(device), img2.to(device)
        affine_matrix = net.get_affine_matrix(img1.unsqueeze(0), img2.unsqueeze(0))
    return affine_matrix

In [7]:
trainloc = r"C:\Users\fhacesga\Desktop\FIRMsDigitizing\RECTDNN\TrainDataset\\"
base_loc = r"D:\FloodChange\BaseRaster\BaseTest.tif"

files = glob.glob(f"{trainloc}*.tif")
train_dataset = SiameseDataset(base_loc, files)

In [8]:
# Define Siamese network and move it to device
net = SiameseNet().to('cuda')
train_siamese_net(net, train_dataset)

OutOfMemoryError: CUDA out of memory. Tried to allocate 62.56 GiB (GPU 0; 8.00 GiB total capacity; 59.00 KiB already allocated; 7.14 GiB free; 2.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF