In [1]:
from src import LoadImageData, SyntheticS2Model

import geopandas as gpd
from geoutils import grid
from shapely.geometry import box, Point

import rasterio as rio
from rasterio.plot import reshape_as_image
from rasterio import merge
from rasterio.windows import Window,from_bounds

import torch

import matplotlib.pyplot as plt
import seaborn as sns


from tqdm import tqdm
from random import sample
import numpy as np
import os,sys

In [2]:
# Automatically reload edited modules (e.g. the local `src` package imported above)
# so changes to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Use matplotlib's built-in "bmh" style sheet for every figure drawn below.
plt.style.use(style="bmh")

In [4]:
# Cloud-mask polygons (GeoJSON); the extent of one of them defines the tiling grid.
cloud_mask = gpd.read_file("../data/s2_cloud_mask.geojson")

In [5]:
# Tiling grid over the bounding box of the cloud-mask polygon at position 9.
# The grid starts 350 CRS units (metres in EPSG:32630) east of the polygon's
# minimum x and uses 640 m cells.
# NOTE(review): the choice of polygon 9 and the 350 m offset are undocumented — TODO explain.
(Xmin, Ymin, Xmax, Ymax) = cloud_mask.geometry.iloc[9].bounds
g = grid.grid(
    Xmin + 350, Xmax,
    Ymin, Ymax,
    cell_size=640,
    crs=32630,
)
# center=False: presumably the lower-left corner of each cell rather than its centre
# (the variable name suggests so) — confirm against geoutils.
Left_Corner = g.generate_point(center=False)

In [6]:
# Tile geometry in CRS units (metres for EPSG:32630). Each inference tile spans
# output_size + delta = 1280 m per side, i.e. 128 x 128 px at 10 m resolution;
# tiles larger than 128 x 128 px were observed to decrease result quality.
output_size = 640
delta = output_size

In [7]:
def _tile_box(corner):
    """Square tile box for one grid corner point.

    The box covers [corner, corner + output_size], padded by delta / 2 on every
    side, so neighbouring tiles overlap.
    """
    half = delta / 2
    return box(
        corner.x - half,
        corner.y - half,
        corner.x + output_size + half,
        corner.y + output_size + half,
    )


grid_geom = Left_Corner.coords.apply(_tile_box)

In [8]:
# Trained checkpoint. Set the S2_CHECKPOINT_PATH environment variable to use a
# different checkpoint instead of editing the hard-coded cluster path below.
CHECKPOINT_PATH = os.environ.get(
    "S2_CHECKPOINT_PATH",
    '/scicore/home/roeoesli/valipo0000/training/marmande/model/lightning_logs/version_53548107/checkpoints/SyntheticS2Model-0.0.1_epoch=03-validation_Loss_epoch=0.010873.ckpt',
)

# Fresh instance (not loaded from the checkpoint); kept only because later cells
# may still reference `model_`.
model_ = SyntheticS2Model()

# Assuming SyntheticS2Model is a LightningModule (the checkpoint lives under
# lightning_logs), load_from_checkpoint is a classmethod and must be called on
# the class: Lightning >= 2.0 raises when it is called on an instance.
# .eval() returns the module itself and switches off training-only behaviour
# (dropout, batch-norm statistic updates) for inference.
train_model = SyntheticS2Model.load_from_checkpoint(checkpoint_path=CHECKPOINT_PATH).eval()

In [9]:
# Input rasters. The filenames suggest Sentinel-2 bands B4/B3/B2/B8 and
# Sentinel-1 VH/VV at 10 m, with dates presumably in DDMMYYYY — confirm with
# the data provenance.
DATA_DIR = '../data'
S2B_PATH = f'{DATA_DIR}/S2_B4328_10m_05072019.tif'
S2A_PATH = f'{DATA_DIR}/S2_B4328_10m_24082019.tif'
S1_PATH = f'{DATA_DIR}/S1_VH_VV_10m_25072019_aligned.tif'

In [10]:
# Tiled inference: for every grid tile, read the three inputs, run the model and
# write the 4-band prediction as a GeoTIFF named after the tile's lower-left bounds.
#
# NOTE(review): the read window is derived from S2B's transform only, so S2A and
# S1 are assumed to share S2B's pixel grid (the S1 file is named "_aligned") —
# confirm before swapping inputs.
S2B = rio.open(S2B_PATH)
S2A = rio.open(S2A_PATH)
S1 = rio.open(S1_PATH)
# The datasets are left open so later cells can reuse them; close them explicitly
# (S2B.close(), S2A.close(), S1.close()) once they are no longer needed.

TILE_PX = 128  # tile side in pixels fed to the model
OUTPUT_DIR = '../data/inference'
# rio.open(..., "w") fails if the directory does not exist yet.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Inference only: eval() switches off training-time behaviour (dropout,
# batch-norm updates) and no_grad() avoids building an autograd graph per tile.
train_model.eval()
with torch.no_grad():
    for idx in tqdm(range(0, len(grid_geom))):
        # .iloc: positional lookup, independent of grid_geom's index labels.
        minx, miny, maxx, maxy = grid_geom.iloc[idx].bounds
        window = from_bounds(minx, miny, maxx, maxy, transform=S2B.transform)
        transform = rio.transform.from_bounds(minx, miny, maxx, maxy, TILE_PX, TILE_PX)

        # resampling=0 is nearest neighbour.
        S2B_img = S2B.read(window=window, resampling=0)
        S2A_img = S2A.read(window=window, resampling=0)
        S1_img = S1.read(window=window, resampling=0)

        # Fail loudly, naming the tile, if a read returns an unexpected shape,
        # rather than relying on reshape to catch (or silently permute) it.
        for label, img, n_bands in (("S2B", S2B_img, 4), ("S2A", S2A_img, 4), ("S1", S1_img, 2)):
            if img.shape != (n_bands, TILE_PX, TILE_PX):
                raise ValueError(
                    f"{label} tile at ({minx}, {miny}) has shape {img.shape}, "
                    f"expected {(n_bands, TILE_PX, TILE_PX)}"
                )

        # [np.newaxis] adds the batch axis: (bands, H, W) -> (1, bands, H, W).
        pred_result = train_model(
            torch.tensor(S2B_img[np.newaxis]),
            torch.tensor(S2A_img[np.newaxis]),
            torch.tensor(S1_img[np.newaxis]),
        )

        OUTPUT_PATH = f'{OUTPUT_DIR}/{minx}_{miny}_inf.tif'
        # NOTE(review): CRS is hard-coded to match the grid (crs=32630); it must
        # agree with S2B.crs, because the tile transform is expressed in S2B's coordinates.
        with rio.open(OUTPUT_PATH, "w",
                      driver='GTiff',
                      count=4,
                      transform=transform,
                      width=TILE_PX,
                      height=TILE_PX,
                      dtype='float32',
                      crs="epsg:32630") as output_file:
            # .cpu() so this also works if the model sits on a GPU; the cast
            # matches the dtype declared for the output file.
            prediction = pred_result.detach().cpu().numpy().squeeze()
            output_file.write(prediction.astype('float32', copy=False))


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3210/3210 [58:04<00:00,  1.09s/it]
