<a href="https://colab.research.google.com/github/inttx/DLAM_SealedSurfaces/blob/main/DLAM_SealedSurfaces.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from os import listdir
from os.path import isfile, join
import os
!pip install rasterio
import rasterio
from rasterio.windows import Window

import numpy as np

from torchvision.models import resnet18
from torchvision.models import ResNet18_Weights
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.optim import *



In [None]:
import warnings
from rasterio.errors import NotGeoreferencedWarning

# Silence rasterio's NotGeoreferencedWarning for the rest of the session; the
# tiles are read as plain pixel arrays, so missing geo-metadata is irrelevant here.
warnings.simplefilter("ignore", category=NotGeoreferencedWarning)

# Default to the CPU and upgrade to the GPU when torch can see one.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [None]:
from google.colab import drive
# Google Drive is mounted under MOUNTPOINT; DATADIR is the 2_Ortho_RGB folder
# inside "My Drive" on that mount.
MOUNTPOINT = '/content/dlam_datasets/'
drive.mount(MOUNTPOINT)
DATADIR = os.path.join(MOUNTPOINT, 'My Drive', '2_Ortho_RGB')



Drive already mounted at /content/dlam_datasets/; to attempt to forcibly remount, call drive.mount("/content/dlam_datasets/", force_remount=True).


In [None]:
# Patches per axis = (img_dim - patch_size) // stride + 1; patches per image is the
# product of the row and column counts.
class PotsdamDataset(Dataset):
    """Tiles paired image/label ``.tif`` rasters into fixed-size square patches.

    Every ``.tif`` file in ``image_dir_path`` is paired, by sorted filename position,
    with the ``.tif`` file at the same position in ``label_dir_path``. Each pair is
    cut into ``patch_size`` x ``patch_size`` windows whose top-left corners are
    ``stride`` pixels apart. The patch grid is computed once from the first image,
    so every image is assumed to share its height and width.

    Args:
        image_dir_path: Directory containing the input ``.tif`` images.
        label_dir_path: Directory containing the label ``.tif`` rasters.
        patch_size: Side length in pixels of each square patch.
        stride: Step in pixels between neighbouring patch origins.
        transform: Stored on the instance but not applied yet (see the TODO in
            ``__getitem__``).

    Raises:
        ValueError: If no images are found, the image and label counts differ, or
            the first image is smaller than ``patch_size``.
    """

    def __init__(self, image_dir_path, label_dir_path, patch_size=256, stride=42, transform=None):
        self.image_dir = image_dir_path
        self.label_dir = label_dir_path
        self.image_files = self._list_tifs(image_dir_path)
        self.image_labels = self._list_tifs(label_dir_path)
        # Images and labels are paired purely by sorted position, so a count
        # mismatch would silently pair the wrong files; an empty directory would
        # otherwise surface as an opaque IndexError in _build_index.
        if not self.image_files:
            raise ValueError(f"No .tif images found in {image_dir_path!r}")
        if len(self.image_files) != len(self.image_labels):
            raise ValueError(
                f"Found {len(self.image_files)} images but {len(self.image_labels)} "
                "label files; every image needs exactly one label file"
            )

        self.transform = transform

        self.patch_size = patch_size
        self.stride = stride
        self.index_map = []  # (row, col) top-left corner of every patch within one image
        self._build_index()
        if not self.index_map:
            raise ValueError(
                f"patch_size={patch_size} does not fit inside {self.image_files[0]!r}"
            )

    @staticmethod
    def _list_tifs(dir_path):
        """Return the sorted full paths of the regular ``.tif`` files in ``dir_path``."""
        paths = (join(dir_path, name) for name in sorted(listdir(dir_path)))
        return [path for path in paths if isfile(path) and path.endswith('.tif')]

    def _build_index(self):
        """Append the top-left (row, col) of every patch to ``self.index_map``.

        Only the first image is opened; its height and width define the grid that
        is reused for every file.
        """
        with rasterio.open(self.image_files[0]) as first_image:
            height, width = first_image.height, first_image.width
        for row in range(0, height - self.patch_size + 1, self.stride):
            for col in range(0, width - self.patch_size + 1, self.stride):
                self.index_map.append((row, col))

    def __len__(self):
        """Total patch count: patches per image times the number of image files."""
        return len(self.index_map) * len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image_patch, label_patch)`` float tensors for flat index ``idx``.

        Indices enumerate every patch of the first file, then every patch of the
        second file, and so on. Tensors are returned on the CPU; the training loop
        moves whole batches to the device. Creating CUDA tensors inside a Dataset
        breaks multi-process DataLoader workers and costs one transfer per sample.
        """
        file_idx, patch_idx = divmod(idx, len(self.index_map))
        row, col = self.index_map[patch_idx]
        # rasterio's Window is (col_off, row_off, width, height).
        window = Window(col, row, self.patch_size, self.patch_size)

        image_patch = self._read_window(self.image_files[file_idx], window)
        label_patch = self._read_window(self.image_labels[file_idx], window)

        # TODO: Apply Transform (self.transform is stored but not used yet)

        return image_patch, label_patch

    @staticmethod
    def _read_window(path, window):
        """Read ``window`` from the raster at ``path`` as a band-first float tensor."""
        with rasterio.open(path) as src:
            return torch.from_numpy(src.read(window=window)).float()


In [None]:
patch_size = 256
stride = 42
batch_size = 512
num_epochs = 2
seed = 42

# Seed torch's global RNG so the new fc head's initialisation and the DataLoader
# shuffle order are reproducible across Restart & Run All.
torch.manual_seed(seed)

# Build both data paths from the Drive mount point (DATADIR is the 2_Ortho_RGB
# folder) instead of relying on the kernel's working directory being /content.
label_dir = os.path.join(MOUNTPOINT, 'My Drive', '5_Labels_all')
dataset = PotsdamDataset(DATADIR, label_dir, patch_size=patch_size, stride=stride)

In [None]:
# ImageNet-pretrained ResNet-18 with its classifier replaced by a linear head that
# emits one value per element of a flattened patch_size x patch_size label patch.
# The training loop reshapes each label patch to this output size, so label rasters
# must have num_label_channels bands.
# NOTE(review): IMAGENET1K_V1 weights expect ImageNet-normalised input, but
# PotsdamDataset currently returns unscaled pixel values (its transform is a TODO).
num_label_channels = 3
custom_resnet18 = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
custom_resnet18.fc = nn.Linear(
    custom_resnet18.fc.in_features,  # 512 for ResNet-18; read it rather than hardcode it
    patch_size * patch_size * num_label_channels,
)
custom_resnet18 = custom_resnet18.to(device)

# Freeze the backbone and train only the new head.
# NOTE: this stops gradient updates only; BatchNorm running statistics in the
# backbone still update whenever the model is in train() mode.
custom_resnet18.requires_grad_(False)
custom_resnet18.fc.requires_grad_(True)

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer, num_epochs=2):
    """Train ``model`` on ``dataloader`` for ``num_epochs`` epochs.

    Each batch's targets are reshaped to ``(batch, pred.shape[-1])`` so they line
    up with the model's flat output before ``loss_fn`` is applied. A progress line
    is printed for every batch and the mean batch loss for every epoch.

    Args:
        dataloader: Iterable of ``(X, y)`` batches that supports ``len``.
        model: Module to train. Batches are moved to the device of its first
            parameter.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar tensor.
        optimizer: Optimizer over (some of) ``model``'s parameters.
        num_epochs: Number of passes over ``dataloader``.

    Returns:
        List holding the mean batch loss of each epoch.

    Raises:
        ValueError: If ``dataloader`` yields no batches (the average would be
            undefined).
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader is empty; cannot compute an average loss")
    # Follow the model rather than a global so inputs always match its device.
    model_device = next(model.parameters()).device

    epoch_losses = []
    for epoch in range(num_epochs):
        # Train mode affects layers such as BatchNorm and Dropout; set it every
        # epoch in case other code switched the model to eval mode.
        model.train()
        epoch_loss = 0.0
        for batch, (X, y) in enumerate(dataloader):
            print(f"{batch} / {num_batches}")
            X = X.to(model_device)
            y = y.to(model_device)

            # Compute prediction and loss against the flattened target.
            pred = model(X)
            loss = loss_fn(pred, y.reshape((y.shape[0], pred.shape[-1])))

            # Clear gradients *before* backward so stale gradients (e.g. left
            # over from an interrupted earlier run) never leak into this step.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
        avg_loss = epoch_loss / num_batches
        epoch_losses.append(avg_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
    return epoch_losses

In [None]:
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
loss_fn = nn.MSELoss()
# Give the optimizer only the trainable parameters (the new fc head); the frozen
# backbone would otherwise be carried along without ever receiving gradients.
trainable_params = [p for p in custom_resnet18.parameters() if p.requires_grad]
# Reference AdamW explicitly rather than through the `from torch.optim import *` star import.
optimizer = torch.optim.AdamW(trainable_params, lr=0.001)
train_loop(dataloader=dataloader, model=custom_resnet18, loss_fn=loss_fn, optimizer=optimizer, num_epochs=num_epochs)

0 / 1394


KeyboardInterrupt: 