In [1]:
from src import LoadImageData, _stretch_im, PatchWiseClassModel

import geopandas as gpd
from geoutils import grid

import rasterio as rio
from rasterio.plot import reshape_as_image
from rasterio import merge
from rasterio.windows import Window,from_bounds

import torch

import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm
from random import sample
import numpy as np
import os,sys

In [2]:
# Enable IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell runs, so edits to the local `src` package are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Apply matplotlib's built-in "bmh" style sheet to all subsequently created figures.
plt.style.use("bmh")

In [4]:
# Extent of the area of interest, passed positionally exactly as before.
# NOTE(review): the order looks like (x-min, x-max, y-min, y-max) in the
# projected CRS, judging by the first polygons of the generated grid --
# confirm against the geoutils `grid` signature.
AOI_EXTENT = (339613.150, 367769.118, 3901246.638, 3935691.740)

# cell_size=160 with a 10 m raster corresponds to the 16 x 16 pixel patches
# read during inference; crs=32654 is the EPSG code used for every output tile.
g = grid.grid(*AOI_EXTENT, cell_size=160, crs=32654)

In [5]:
# Materialise the grid as a table with one row per cell; the `geom` column
# holds each cell's polygon and is iterated over during inference.
aoi = g.generate_grid()

In [6]:
# Preview the first few grid cells as a quick sanity check of the geometry.
aoi.head()

Unnamed: 0,geom
0,"POLYGON ((339773.000 3901246.000, 339773.000 3..."
1,"POLYGON ((339933.000 3901246.000, 339933.000 3..."
2,"POLYGON ((340093.000 3901246.000, 340093.000 3..."
3,"POLYGON ((340253.000 3901246.000, 340253.000 3..."
4,"POLYGON ((340413.000 3901246.000, 340413.000 3..."


In [7]:
# Checkpoint of the trained patch-wise classifier used for inference.
CHECKPOINT_PATH = '../model/lightning_logs/version_3/checkpoints/road-classifier-0.0.1_epoch=35-validation_Loss_epoch=0.671447.ckpt'

# `load_from_checkpoint` builds and returns a fresh model, so it is called on
# the class instead of on a throw-away instance: constructing that instance was
# wasted work, and recent PyTorch Lightning releases refuse the instance call.
# NOTE(review): assumes PatchWiseClassModel is a LightningModule -- confirm in `src`.
train_model = PatchWiseClassModel.load_from_checkpoint(checkpoint_path=CHECKPOINT_PATH)

In [8]:
INPUT_PATH = '../data/Sentinel-2_B4328_10m.tif'
OUTPUT_DIR = '../data/inference/'
PATCH_SIZE = 16  # pixels per side of one grid cell; each read is reshaped to (1, 4, 16, 16)

# Create the tile directory up front so the first write cannot fail on a
# missing folder.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Inference only: put the network in evaluation mode so training-time layers
# (dropout, batch-norm statistics updates) do not perturb the predictions.
# NOTE(review): assumes `train_model` is a torch.nn.Module -- it is called like one.
train_model.eval()

# Open the source raster once (instead of once per cell, never closed) and
# disable autograd bookkeeping for the whole pass.
with rio.open(INPUT_PATH) as INPUT_IMG, torch.no_grad():
    # Iterate the geometries directly so the loop does not rely on `aoi`
    # having a default 0..n-1 index.
    for geom in tqdm(aoi['geom'], total=len(aoi)):
        minx, miny, maxx, maxy = geom.bounds

        # Pixel window of this cell in the source raster, and the affine
        # transform that georeferences the PATCH_SIZE x PATCH_SIZE output tile.
        window = from_bounds(minx, miny, maxx, maxy, transform=INPUT_IMG.transform)
        transform = rio.transform.from_bounds(minx, miny, maxx, maxy, PATCH_SIZE, PATCH_SIZE)

        # NOTE(review): a cell reaching past the raster edge would yield a
        # smaller array and make the reshape below fail; the recorded run
        # completed, so the grid presumably lies inside the raster -- confirm.
        input_result = INPUT_IMG.read(window=window)
        batch = torch.tensor(input_result.reshape(1, 4, PATCH_SIZE, PATCH_SIZE))

        # One class index per patch, broadcast over the whole tile.
        pred_result = torch.argmax(train_model(batch))
        PRED_IMG = np.full((1, PATCH_SIZE, PATCH_SIZE), pred_result.item(), dtype='uint8')

        # Tiles are named after the cell's lower-left corner coordinates.
        OUTPUT_PATH = os.path.join(OUTPUT_DIR, f'{minx}_{miny}_inf.tif')
        with rio.open(
            OUTPUT_PATH,
            "w",
            driver='GTiff',
            count=1,
            transform=transform,
            width=PATCH_SIZE,
            height=PATCH_SIZE,
            dtype='uint8',
            crs="epsg:32654",
        ) as output_file:
            output_file.write(PRED_IMG)

100%|█████████████████████████████████████████████████████████████████████████████████| 37625/37625 [1:27:30<00:00,  7.17it/s]


In [None]:
INFERENCE_DIR = '../data/inference/'

# Collect the per-cell prediction tiles; sorted because os.listdir returns
# entries in arbitrary order and the merge input order should be reproducible.
items = os.listdir(INFERENCE_DIR)
img_list = sorted(names for names in items if names.endswith(".tif"))

# Open every tile, merge them into one array plus its affine transform, and
# always close the datasets afterwards -- the original left every handle open.
images = []
try:
    for fname in img_list:
        images.append(rio.open(os.path.join(INFERENCE_DIR, fname)))
    full_image, transform = merge.merge(images)
finally:
    for image in images:
        image.close()

In [None]:
# Destination of the single GeoTIFF that holds the merged prediction tiles.
MOSAIC_PATH = '../data/inference_mosaic.tif'

In [None]:
# The merged array is laid out as (bands, rows, cols); derive the GeoTIFF
# dimensions from it and pair them with the transform returned by the merge.
band_count, n_rows, n_cols = full_image.shape
mosaic_profile = dict(
    driver='GTiff',
    count=band_count,
    transform=transform,
    width=n_cols,
    height=n_rows,
    dtype=full_image.dtype,
    crs="epsg:32654",
)

# Write all bands of the mosaic in one call.
with rio.open(MOSAIC_PATH, "w", **mosaic_profile) as output_file:
    output_file.write(full_image)