In [87]:
from osgeo import gdal, gdalconst, ogr
from skimage import img_as_ubyte
import numpy as np
from IPython.display import Image, display

import sys
sys.path.append("..")

from Utils.getPoints import pointsAsPixels
from Utils.croptest import croptest
from Utils.crop import cropimage
from Utils.plotimage import plot_predict
import Utils.myGeoTools as mgt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF 

import fiona
from pyproj import CRS
import random

In [88]:
print(torch.__version__)

2.0.1


In [89]:
sar_image_list = ["SNAP_Images/1D5D/S1A_EW_GRDM_1SDH_20230204T145728_20230204T145828_047087_05A61A_1D5D_Orb_NR_Cal_Spk_THR_SHP_EC.tif",
                  "SNAP_Images/06D2/S1A_EW_GRDM_1SDH_20230212T135259_20230212T135341_047203_05A9F4_06D2_Orb_NR_Cal_Spk_THR_SHP_EC.tif",
                  "SNAP_Images/17F5/S1A_EW_GRDM_1SDH_20230212T135159_20230212T135259_047203_05A9F4_17F5_Orb_NR_Cal_Spk_THR_SHP_EC.tif",
                  "SNAP_Images/84EF/S1A_EW_GRDM_1SDH_20230118T131123_20230118T131232_046838_059DBE_84EF_Orb_NR_Cal_Spk_THR_SHP_EC.tif",
                  "SNAP_Images/264D/S1A_EW_GRDM_1SDH_20230209T150434_20230209T150539_047160_05A886_264D_Orb_NR_Cal_Spk_THR_SHP_EC.tif",
                  "SNAP_Images/A890/S1A_EW_GRDM_1SDH_20230116T132728_20230116T132828_046809_059CC4_A890_Orb_NR_Cal_Spk_THR_SHP_EC.tif",
                  "SNAP_Images/D15B/S1A_EW_GRDM_1SDH_20230118T131019_20230118T131123_046838_059DBE_D15B_Orb_NR_Cal_Spk_THR_SHP_EC.tif",
                  "SNAP_Images/E82B/S1A_EW_GRDM_1SDH_20230209T150539_20230209T150639_047160_05A886_E82B_Orb_NR_Cal_Spk_THR_SHP_EC.tif",]

Simple CNN

In [90]:
# CNN Model
class CNN(nn.Module):
    def __init__(self, height, width):
        super(CNN, self).__init__()
        self.height = height
        self.width = width
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        fc_layers = nn.Sequential(
            nn.Linear(x.shape[1], 1),
            nn.Sigmoid()
        )
        x = fc_layers(x)
        return x

In [91]:
# Load model
model = torch.load("CNN_model.pt")
model.eval()

CNN(
  (conv_layers): Sequential(
    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
)

In [92]:
def exportShpFromCentroid(centroids, shpPath, gt):
    """Export a shapefile from a list of centroid, after
    remapping to coordinate system of the image.
    """
    # Save centroids as point shapefile
    schema = {"geometry": "Point", "properties": {}}
    crs = CRS.from_epsg(4326)
    with fiona.open(shpPath, "w", "ESRI Shapefile", schema, crs=crs) as output:
        for centroid in centroids:
            px, py = centroid
            cx, cy = mgt.pixel2coord(gt, px, py)
            point = {"type": "Point", "coordinates": (cx, cy)}
            output.write({"geometry": point, "properties": {}})

In [138]:
for index in range(len(sar_image_list)):
    sar_image = sar_image_list[index]
    print(sar_image)

    # import shapefiles and image
    # Open the SAR data and get transform
    ds = gdal.Open(sar_image, gdalconst.GA_ReadOnly)
    image = np.asarray(ds.GetRasterBand(1).ReadAsArray())

    image = img_as_ubyte(image)
    height, width = image.shape

    gt=ds.GetGeoTransform()

    # crop
    shape = 30
    xsize, ysize = image.shape
    inp = []
    step = [250, 400, 150, 300, 250, 150, 200, 550]
    for x in range(15, image.shape[1]-15, step[index]):
        for y in range(15, image.shape[0]-15, step[index]):
            x_max = min(x + shape//2, xsize)
            x_min = max(x - shape//2, 0)
            y_max = min(y + shape//2, ysize)
            y_min = max(y - shape//2, 0)
            sub = image[y_min:y_max, x_min:x_max]
            if sub.shape == (30, 30):
                inp.append((sub, (y, x)))
    
    batch_size = 64
    train_loader = DataLoader(inp, batch_size=batch_size, shuffle=False)

    output = np.zeros_like(image)

    for images, (y, x) in train_loader:
        images = images.float()
        images = images.unsqueeze(1)
        outputs = model(images)
        pre = (outputs > 0.5).float()
        for i in range(len(pre)):
            if pre[i] == 1:
                output[y, x] = 1
    print(np.sum(output))
    exportShpFromCentroid(output, "predict_shp.shp", gt)

SNAP_Images/1D5D/S1A_EW_GRDM_1SDH_20230204T145728_20230204T145828_047087_05A61A_1D5D_Orb_NR_Cal_Spk_THR_SHP_EC.tif




2961
SNAP_Images/06D2/S1A_EW_GRDM_1SDH_20230212T135259_20230212T135341_047203_05A9F4_06D2_Orb_NR_Cal_Spk_THR_SHP_EC.tif




704
SNAP_Images/17F5/S1A_EW_GRDM_1SDH_20230212T135159_20230212T135259_047203_05A9F4_17F5_Orb_NR_Cal_Spk_THR_SHP_EC.tif




5568
SNAP_Images/84EF/S1A_EW_GRDM_1SDH_20230118T131123_20230118T131232_046838_059DBE_84EF_Orb_NR_Cal_Spk_THR_SHP_EC.tif




1617
SNAP_Images/264D/S1A_EW_GRDM_1SDH_20230209T150434_20230209T150539_047160_05A886_264D_Orb_NR_Cal_Spk_THR_SHP_EC.tif




1984
SNAP_Images/A890/S1A_EW_GRDM_1SDH_20230116T132728_20230116T132828_046809_059CC4_A890_Orb_NR_Cal_Spk_THR_SHP_EC.tif




4608
SNAP_Images/D15B/S1A_EW_GRDM_1SDH_20230118T131019_20230118T131123_046838_059DBE_D15B_Orb_NR_Cal_Spk_THR_SHP_EC.tif




3712
SNAP_Images/E82B/S1A_EW_GRDM_1SDH_20230209T150539_20230209T150639_047160_05A886_E82B_Orb_NR_Cal_Spk_THR_SHP_EC.tif




324
