<a href="https://colab.research.google.com/github/lacykaltgr/satellite-image-segmentation/blob/main/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


class UNet(nn.Module):
    """U-Net style encoder/decoder producing per-pixel class scores.

    Encoder: four stages of (3x3 conv + ReLU) x2 -> BatchNorm -> 2x2 max-pool,
    widening ``in_channels`` -> 64 -> 128 -> 256 -> 512 channels (stage 4 also
    applies Dropout2d). Bottleneck: two 1024-channel convs + BatchNorm + Dropout2d.
    Decoder: four stages that bilinearly resize to the matching encoder feature
    map, apply a 3x3 conv, concatenate that encoder map (skip connection) and
    refine with two 3x3 convs + BatchNorm. A final 1x1 conv maps 16 channels to
    ``out_channels``.

    Args:
        in_channels: Number of channels of the input tensor (default 4).
        out_channels: Number of output channels / classes (default 3).

    ``forward`` returns raw, unnormalised scores (logits) with the same spatial
    size as the input; no softmax is applied inside the model.
    """

    def __init__(self, in_channels=4, out_channels=3):
        super(UNet, self).__init__()

        # Left side of the U-Net (encoder)
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding='same')
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding='same')
        self.batchnorm1 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(2)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding='same')
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding='same')
        self.batchnorm2 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(2)

        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding='same')
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding='same')
        self.batchnorm3 = nn.BatchNorm2d(256)
        self.pool3 = nn.MaxPool2d(2)

        self.conv7 = nn.Conv2d(256, 512, kernel_size=3, padding='same')
        self.conv8 = nn.Conv2d(512, 512, kernel_size=3, padding='same')
        self.batchnorm4 = nn.BatchNorm2d(512)
        self.dropout4 = nn.Dropout2d(p=0.5)
        self.pool4 = nn.MaxPool2d(2)

        # Bottom of the U-Net (bottleneck)
        self.conv9 = nn.Conv2d(512, 1024, kernel_size=3, padding='same')
        self.conv10 = nn.Conv2d(1024, 1024, kernel_size=3, padding='same')
        self.batchnorm5 = nn.BatchNorm2d(1024)
        self.dropout5 = nn.Dropout2d(p=0.5)

        # Upsampling starts, right side of the U-Net (decoder).
        # Each upconvN halves the channels after resizing; the following conv
        # takes twice that many because the skip connection is concatenated.
        self.upconv6 = nn.Conv2d(1024, 512, kernel_size=3, padding='same')
        self.conv11 = nn.Conv2d(1024, 512, kernel_size=3, padding='same')
        self.conv12 = nn.Conv2d(512, 512, kernel_size=3, padding='same')
        self.batchnorm6 = nn.BatchNorm2d(512)

        self.upconv7 = nn.Conv2d(512, 256, kernel_size=3, padding='same')
        self.conv13 = nn.Conv2d(512, 256, kernel_size=3, padding='same')
        self.conv14 = nn.Conv2d(256, 256, kernel_size=3, padding='same')
        self.batchnorm7 = nn.BatchNorm2d(256)

        self.upconv8 = nn.Conv2d(256, 128, kernel_size=3, padding='same')
        self.conv15 = nn.Conv2d(256, 128, kernel_size=3, padding='same')
        self.conv16 = nn.Conv2d(128, 128, kernel_size=3, padding='same')
        self.batchnorm8 = nn.BatchNorm2d(128)

        self.upconv9 = nn.Conv2d(128, 64, kernel_size=3, padding='same')
        self.conv17 = nn.Conv2d(128, 64, kernel_size=3, padding='same')
        self.conv18 = nn.Conv2d(64, 64, kernel_size=3, padding='same')
        self.conv19 = nn.Conv2d(64, 16, kernel_size=3, padding='same')
        self.batchnorm9 = nn.BatchNorm2d(16)

        # Output layer: 1x1 conv producing raw logits (no softmax here;
        # nn.CrossEntropyLoss applies log-softmax internally).
        self.conv20 = nn.Conv2d(16, out_channels, kernel_size=1)

    def forward(self, x):
        """Return per-pixel logits of shape ``(N, out_channels, H, W)`` for input ``(N, in_channels, H, W)``."""
        # Left side of the U-Net
        conv1 = F.relu(self.conv1(x))
        conv1 = F.relu(self.conv2(conv1))
        conv1 = self.batchnorm1(conv1)
        pool1 = self.pool1(conv1)

        conv2 = F.relu(self.conv3(pool1))
        conv2 = F.relu(self.conv4(conv2))
        conv2 = self.batchnorm2(conv2)
        pool2 = self.pool2(conv2)

        conv3 = F.relu(self.conv5(pool2))
        conv3 = F.relu(self.conv6(conv3))
        conv3 = self.batchnorm3(conv3)
        pool3 = self.pool3(conv3)

        conv4 = F.relu(self.conv7(pool3))
        conv4 = F.relu(self.conv8(conv4))
        conv4 = self.batchnorm4(conv4)
        drop4 = self.dropout4(conv4)
        pool4 = self.pool4(drop4)

        # Bottom of the U-Net
        conv5 = F.relu(self.conv9(pool4))
        conv5 = F.relu(self.conv10(conv5))
        conv5 = self.batchnorm5(conv5)
        drop5 = self.dropout5(conv5)

        # Upsampling starts, right side of the U-Net. Resizing to the exact skip
        # tensor size (rather than a fixed x2) keeps odd input sizes aligned.
        intp = F.interpolate(drop5, size=drop4.shape[2:], mode='bilinear', align_corners=True)
        up6 = F.relu(self.upconv6(intp))
        merge6 = torch.cat([drop4, up6], dim=1)
        conv6 = F.relu(self.conv11(merge6))
        conv6 = F.relu(self.conv12(conv6))
        conv6 = self.batchnorm6(conv6)

        up7 = F.relu(self.upconv7(F.interpolate(conv6, size=conv3.shape[2:], mode='bilinear', align_corners=True)))
        merge7 = torch.cat([conv3, up7], dim=1)
        conv7 = F.relu(self.conv13(merge7))
        conv7 = F.relu(self.conv14(conv7))
        conv7 = self.batchnorm7(conv7)

        up8 = F.relu(self.upconv8(F.interpolate(conv7, size=conv2.shape[2:], mode='bilinear', align_corners=True)))
        merge8 = torch.cat([conv2, up8], dim=1)
        conv8 = F.relu(self.conv15(merge8))
        conv8 = F.relu(self.conv16(conv8))
        conv8 = self.batchnorm8(conv8)

        up9 = F.relu(self.upconv9(F.interpolate(conv8, size=conv1.shape[2:], mode='bilinear', align_corners=True)))
        merge9 = torch.cat([conv1, up9], dim=1)
        conv9 = F.relu(self.conv17(merge9))
        conv9 = F.relu(self.conv18(conv9))
        conv9 = F.relu(self.conv19(conv9))
        conv9 = self.batchnorm9(conv9)

        # Output layer: raw logits, no softmax applied.
        conv10 = self.conv20(conv9)

        return conv10

    def train_model(self, train_loader, valid_loader, num_epochs=100, learning_rate=1e-4, device='cuda'):
        """Train with Adam and CrossEntropyLoss, evaluating on ``valid_loader`` each epoch.

        Args:
            train_loader: Iterable of ``(inputs, targets)`` batches used for optimisation.
            valid_loader: Iterable of ``(inputs, targets)`` batches used only for evaluation.
            num_epochs: Number of passes over ``train_loader``.
            learning_rate: Adam learning rate.
            device: Device the model and every batch are moved to.

        Returns:
            ``(train_loss, valid_loss)``: two lists holding the mean batch loss of
            each epoch. An epoch whose loader yields no batches records ``nan``
            instead of raising ``ZeroDivisionError``.
        """
        def _mean(total, count):
            # Average loss per batch; nan when the loader produced nothing.
            return total / count if count else float('nan')

        self.to(device)
        # NOTE(review): float targets are interpreted by CrossEntropyLoss as
        # per-pixel class probabilities (same shape as the logits), integer
        # targets as class indices. Confirm the loaders provide one of these.
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        train_loss = []
        valid_loss = []
        for epoch in range(num_epochs):
            self.train()
            epoch_train_loss = 0.0
            n_train_batches = 0
            for inputs, targets in train_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                optimizer.zero_grad()
                loss = criterion(self(inputs), targets)
                loss.backward()
                optimizer.step()
                epoch_train_loss += loss.item()
                n_train_batches += 1

            self.eval()
            epoch_valid_loss = 0.0
            n_valid_batches = 0
            # Evaluation only: skip autograd bookkeeping to save memory and time.
            with torch.no_grad():
                for inputs, targets in valid_loader:
                    inputs = inputs.to(device)
                    targets = targets.to(device)
                    epoch_valid_loss += criterion(self(inputs), targets).item()
                    n_valid_batches += 1

            mean_train = _mean(epoch_train_loss, n_train_batches)
            mean_valid = _mean(epoch_valid_loss, n_valid_batches)
            train_loss.append(mean_train)
            valid_loss.append(mean_valid)
            print(f'Epoch {epoch:03}: | Train Loss: {mean_train:.5f} | Validation Loss: {mean_valid:.5f}')
        return train_loss, valid_loss

In [2]:
import numpy as np
import imageio
from google.colab import drive

# Mount Google Drive at /content/drive so the dataset paths that
# load_dataset() globs (under /content/drive/MyDrive/...) are reachable.
drive.mount('/content/drive')


def numericalSort(value):
    """Natural-sort key: split *value* into alternating text and integer pieces.

    Runs of digits are converted to ``int`` so that, for example, ``"2.tif"``
    sorts before ``"10.tif"``. The returned list always starts and ends with a
    (possibly empty) text piece.
    """
    import re
    pieces = re.split(r'(\d+)', value)
    # A capturing group makes re.split keep the digit runs, at odd indices.
    return [int(piece) if index % 2 == 1 else piece for index, piece in enumerate(pieces)]

def load_images(fnames):
    """Load the first frame of each image file as a NumPy array.

    Args:
        fnames: Iterable of image file paths.

    Returns:
        List of ``np.ndarray``, one per path, in the same order as ``fnames``.
    """
    d_list = []
    for fname in fnames:
        # imageio.read returns a reader that holds the file open; close it
        # explicitly so that loading many images does not leak file handles.
        reader = imageio.read(fname)
        try:
            image = np.array(reader.get_data(0))
        finally:
            reader.close()
        d_list.append(image)
    return d_list


def load_dataset():
    import glob

    # List of file names of actual Satellite images for traininig
    filelist_trainx = sorted(glob.glob('/content/drive/MyDrive/colab/sat-images/The-Eye-in-the-Sky-dataset/sat/*.tif'), key=numericalSort)
    # List of file names of classified images for traininig
    filelist_trainy = sorted(glob.glob('/content/drive/MyDrive/colab/sat-images/The-Eye-in-the-Sky-dataset/gt/*.tif'), key=numericalSort)
    # List of file names of actual Satellite images for testing
    filelist_testx = sorted(glob.glob('/content/drive/MyDrive/colab/sat-images/The-Eye-in-the-Sky-test-data/sat_test/*.tif'), key=numericalSort)

    # Making array of all the training sat images as it is without any cropping
    x = load_images(filelist_trainx)
    y = load_images(filelist_trainy)
    x_test = load_images(filelist_testx)

    x_train = x[:-1]
    y_train = y[:-1]
    x_val = [x[-1]]
    y_val = [y[-1]]

    return x_train, y_train, x_val, y_val, x_test


Mounted at /content/drive


In [3]:
import torch
from torch.utils.data import Dataset, DataLoader

# Load every image into memory; load_dataset holds out the last training pair
# as the validation split.
x_train, y_train, x_val, y_val, x_test = load_dataset()

class XYDataset(Dataset):
    """Map-style dataset of ``(input, target)`` pairs backed by in-memory data.

    Every item is cast to ``uint8`` and then to a float tensor. Axis 0 is then
    swapped with axis 2 (axis 1 for 2-d items), so an ``(H, W, C)`` image becomes
    ``(C, W, H)``. 0-d/1-d items, such as the entries of a ``torch.zeros(n)``
    placeholder used for unlabelled data, are returned without transposing.

    Args:
        x_data: Indexable sequence of input arrays.
        y_data: Indexable sequence of targets (NumPy arrays or tensors),
            indexed in parallel with ``x_data``.
    """

    def __init__(self, x_data, y_data):
        self.x_data = x_data
        self.y_data = y_data

    def __len__(self):
        return len(self.x_data)

    @staticmethod
    def _to_tensor(item):
        """Convert one sample to a float tensor with its first axis moved to the end and vice versa."""
        # np.asarray also accepts torch tensors (the original .astype call did not).
        # NOTE(review): the uint8 cast wraps values >= 256; if the source images
        # are 16-bit this discards information - confirm the data range.
        tensor = torch.tensor(np.asarray(item).astype('uint8'), dtype=torch.float)
        if tensor.dim() < 2:
            # Scalars / 1-d placeholders have no axes to reorder.
            return tensor
        # NOTE(review): transpose(0, 2) maps (H, W, C) to (C, W, H), i.e. images
        # are also spatially transposed; permute(2, 0, 1) would keep (C, H, W).
        # Kept as-is because inputs and targets are transformed identically.
        last_axis = 2 if tensor.dim() >= 3 else 1
        return torch.transpose(tensor, 0, last_axis)

    def __getitem__(self, idx):
        return self._to_tensor(self.x_data[idx]), self._to_tensor(self.y_data[idx])

# Labelled splits first; the unlabelled test split gets one zero per image
# as a stand-in target so every split yields (x, y) pairs.
train_dataset, val_dataset = XYDataset(x_train, y_train), XYDataset(x_val, y_val)
test_dataset = XYDataset(x_test, torch.zeros(len(x_test)))

# One image per batch; only the training loader reshuffles each epoch.
batch_size = 1
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(dataset, batch_size=batch_size)
    for dataset in (val_dataset, test_dataset)
)

In [4]:
# Build the model with its default channel counts and train for 10 epochs.
# train_model moves the model to device='cuda' unless told otherwise.
# The returned (train_loss, valid_loss) tuple is left as the cell's last
# expression so the notebook displays it.
unet = UNet()
unet.train_model(train_loader, val_loader, num_epochs=10)

Epoch 000: | Train Loss: 469.75180 | Validation Loss: 322.62433
Epoch 001: | Train Loss: 451.70929 | Validation Loss: 320.64420
Epoch 002: | Train Loss: 448.31623 | Validation Loss: 313.53116
Epoch 003: | Train Loss: 445.59826 | Validation Loss: 307.60834
Epoch 004: | Train Loss: 443.52361 | Validation Loss: 303.82788
Epoch 005: | Train Loss: 442.13099 | Validation Loss: 304.66428
Epoch 006: | Train Loss: 441.56375 | Validation Loss: 297.72519
Epoch 007: | Train Loss: 439.82790 | Validation Loss: 298.10300
Epoch 008: | Train Loss: 438.81613 | Validation Loss: 291.23383
Epoch 009: | Train Loss: 437.61059 | Validation Loss: 294.64828


([469.7518005371094,
  451.70928720327527,
  448.3162348820613,
  445.5982642540565,
  443.5236088679387,
  442.13099083533655,
  441.5637465256911,
  439.8279043344351,
  438.8161292442909,
  437.610591008113],
 [322.62432861328125,
  320.6441955566406,
  313.5311584472656,
  307.60833740234375,
  303.827880859375,
  304.6642761230469,
  297.7251892089844,
  298.1029968261719,
  291.23382568359375,
  294.6482849121094])

In [9]:
# Print the shape of every parameter tensor, in registration order.
for _name, weight in unet.named_parameters():
    print(weight.shape)

torch.Size([64, 4, 3, 3])
torch.Size([64])
torch.Size([64, 64, 3, 3])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([128, 64, 3, 3])
torch.Size([128])
torch.Size([128, 128, 3, 3])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([256, 128, 3, 3])
torch.Size([256])
torch.Size([256, 256, 3, 3])
torch.Size([256])
torch.Size([256])
torch.Size([256])
torch.Size([512, 256, 3, 3])
torch.Size([512])
torch.Size([512, 512, 3, 3])
torch.Size([512])
torch.Size([512])
torch.Size([512])
torch.Size([1024, 512, 3, 3])
torch.Size([1024])
torch.Size([1024, 1024, 3, 3])
torch.Size([1024])
torch.Size([1024])
torch.Size([1024])
torch.Size([512, 1024, 3, 3])
torch.Size([512])
torch.Size([512, 1024, 3, 3])
torch.Size([512])
torch.Size([512, 512, 3, 3])
torch.Size([512])
torch.Size([512])
torch.Size([512])
torch.Size([256, 512, 3, 3])
torch.Size([256])
torch.Size([256, 512, 3, 3])
torch.Size([256])
torch.Size([256, 256, 3, 3])
torch.Size([256])
torch.Size([256])
torch.Size([256