In [28]:
import os


In [29]:
# List the downloaded Sentinel-2 products. os.listdir returns entries in arbitrary,
# filesystem-dependent order; sorting makes the tile grouping below reproducible.
safes = sorted(os.listdir('../../../Yiwu/SAFES/'))

In [30]:
# Group product names by tile ID. Names are '_'-separated: field 2 is taken as the
# acquisition timestamp and field 5 as the tile ID (e.g. T50RQS in the listing below).
tileid_date_map = {}
for safe in safes:
    if 'SAFE' not in safe:
        continue  # ignore directory entries whose name does not contain 'SAFE'
    parts = safe.split('_')
    date, tileid = parts[2], parts[5]
    tileid_date_map.setdefault(tileid, []).append(date)

In [31]:
# Print one line per tile: the tile ID followed by its timestamps in ascending order.
# Note: each date list inside tileid_date_map is sorted in place.
for tileid, dates in tileid_date_map.items():
    dates.sort()
    print(' '.join([tileid] + dates) + ' ')

T50RQS 20151126T024032 20170228T023631 20170429T023551 20171001T023529 20171026T023801 20171210T024059 20171220T024109 20171225T024121 20180109T024049 20180213T023821 20180223T023711 20180409T023549 20180419T023549 20180613T023551 20180728T023549 
T51RTM 20151126T024032 20160623T024048 20170228T023631 20170429T023551 20171031T023819 20171210T024059 20171220T024109 20171225T024121 20180109T024049 20180208T023839 20180213T023821 20180409T023549 20180419T023549 20181001T023551 
T51RTN 20151126T024032 20170228T023631 20170429T023551 20170713T023549 20171031T023819 20171210T024059 20171220T024109 20171225T024121 20180109T024049 20180208T023839 20180213T023821 20180409T023549 20180419T023549 20180728T023549 
T50RQT 20151126T024032 20161230T024112 20170228T023631 20170429T023551 20171026T023801 20171210T024059 20171220T024109 20171225T024121 20180109T024049 20180208T023839 20180213T023821 20180223T023711 20180310T023539 20180409T023549 20180419T023549 20180514T023551 20180613T023551 20180728T

In [1]:
# Hand-picked acquisition timestamps per tile, presumably chosen from the per-tile listing
# above: one line per tile, tile ID first, then the selected timestamps.
# NOTE(review): this rebinds `dates`, which the previous cell used for a per-tile list.
dates = """T50RQS 20151126T024032 20170228T023631 20171225T024121 20180728T023549 
T51RTM 20151126T024032 20170228T023631 20171225T024121 20181001T023551 
T51RTN 20151126T024032 20170228T023631 20171225T024121 20180728T023549 
T50RQT 20151126T024032 20170228T023631 20171225T024121 20181001T023551"""

In [2]:
# Parse the `dates` table into {tile_id: [timestamp, ...]}: the first token on each line
# is the tile, the remaining tokens are its selected timestamps.
samples = {
    fields[0]: fields[1:]
    for fields in (line.split() for line in dates.split('\n'))
}

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import cv2 
import sys 
import glob
import random
from multiprocessing import Pool
from itertools import product

%matplotlib inline
import matplotlib.pyplot as plt

import os
import math

sys.path.append('../utils')
sys.path.append('../models')
from dataloaders import *
from unet_blocks import *
from metrics_and_losses import *


import rasterio

In [2]:
USE_CUDA = torch.cuda.is_available()
DEVICE = 0  # index of the GPU used when CUDA is available


def w(v):
    """Return `v` moved to GPU `DEVICE` if CUDA is available, otherwise `v` unchanged."""
    return v.cuda(DEVICE) if USE_CUDA else v

In [3]:
# Build the Siamese U-Net classifier (13 input channels, 1 output map) and load the
# pretrained weights named in the checkpoint path.
model = w(UNetClassify(layers=6, init_filters=32, num_channels=13, fusion_method='mul', out_dim=1))
# Load the checkpoint onto the same device `w` placed the model on. A hard-coded 'cuda:0'
# ignores DEVICE and fails outright on machines without CUDA.
map_location = 'cuda:{}'.format(DEVICE) if USE_CUDA else 'cpu'
weights = torch.load('../../weights/onera/unet_siamese_prod_relu_inp64_13band_2dates_focal_hm_cnc_all_14_cities.pt',
                    map_location=map_location)
model.load_state_dict(weights)
# NOTE(review): model.eval() is disabled, so the model stays in training mode during the
# inference loop below. If UNetClassify contains BatchNorm/Dropout layers this changes the
# predictions — confirm the omission is intentional.
# model.eval()

In [4]:
# Collect the per-band JP2 files of tile T50RQS for the two acquisition dates compared here.
# `**` without recursive=True matches exactly one directory level (the granule folder).
d1_bands = sorted(glob.glob('../../../Yiwu/SAFES/*20151126T024032*T50RQS*/GRANULE/**/IMG_DATA/*_B*.jp2'))
d2_bands = sorted(glob.glob('../../../Yiwu/SAFES/*20170228T023631*T50RQS*/GRANULE/**/IMG_DATA/*_B*.jp2'))

# Bands are paired between dates by list position and the model takes 13 channels per date
# (num_channels=13). A missing band or a second matching SAFE would silently misalign the
# inputs, so fail fast here instead.
assert len(d1_bands) == 13, 'expected 13 date-1 band files, found %d' % len(d1_bands)
assert len(d2_bands) == 13, 'expected 13 date-2 band files, found %d' % len(d2_bands)

In [5]:
def read_band(band):
    """Read the first band of the raster file at path `band` and return it as a 2-D array.

    The dataset is opened in a `with` block so its file handle is released after reading;
    this runs once per file inside a worker pool, so unreleased handles would accumulate.
    """
    with rasterio.open(band) as src:
        return src.read(1)

def read_bands(band_paths, processes=26):
    """Read every raster in `band_paths` in parallel with `read_band`.

    Parameters
    ----------
    band_paths : list of str
        Raster file paths.
    processes : int, default 26
        Number of worker processes.

    Returns
    -------
    list
        One array per path, in the same order as `band_paths`.
    """
    # Pool.map blocks until every result is back, so tearing the pool down on exit loses
    # nothing; the context manager also cleans up the workers if a read raises.
    with Pool(processes) as pool:
        return pool.map(read_band, band_paths)

def _match_band(two_date):
    """Pool.map adapter: unpack a [date1_band, date2_band] pair and call match_band.

    match_band (from the star imports) receives the date-2 band first and the date-1 band
    second. match_bands uses its result as the new date-2 band, so presumably it adjusts its
    first argument towards the second — confirm against match_band's definition.
    """
    band_date1 = two_date[0]
    band_date2 = two_date[1]
    return match_band(band_date2, band_date1)

def match_bands(date1, date2, processes=13):
    """Adjust each date-2 band against the date-1 band at the same position, in parallel.

    Parameters
    ----------
    date1, date2 : sequences of band arrays
        Bands of the two dates, paired by position.
    processes : int, default 13
        Number of worker processes.

    Returns
    -------
    list
        The adjusted date-2 bands, in the same order as `date2`.

    Raises
    ------
    ValueError
        If the two dates do not contain the same number of bands.
    """
    if len(date1) != len(date2):
        # Pairing by index with mismatched lengths would either raise IndexError or silently
        # drop trailing date-2 bands (shrinking the caller's slice-assigned list).
        raise ValueError('date1 has %d bands but date2 has %d' % (len(date1), len(date2)))
    pairs = [[band1, band2] for band1, band2 in zip(date1, date2)]
    with Pool(processes) as pool:
        return pool.map(_match_band, pairs)
        
def _resize(band, size=10980):
    """Resample `band` to a `size` x `size` grid with cv2.resize (its default interpolation).

    The default of 10980 is the common grid every band is brought onto before stacking.
    cv2.resize takes the target as (width, height); both are `size` here.
    """
    return cv2.resize(band, (size, size))
    
def stack_bands(bands, num_bands=13, processes=26):
    """Resample and stretch a combined band list, then split it into two per-date stacks.

    Parameters
    ----------
    bands : list of arrays
        The date-1 bands followed by the date-2 bands.
    num_bands : int, default 13
        Number of date-1 bands; entries from this index on are treated as date 2.
    processes : int, default 26
        Number of worker processes.

    Returns
    -------
    (ndarray, ndarray)
        float32 stacks for date 1 and date 2, with bands along axis 0.
    """
    # One pool serves both passes and is cleaned up even if a worker raises.
    with Pool(processes) as pool:
        resized = pool.map(_resize, bands)
        # stretch_8bit comes from the star imports; presumably rescales to 0-255 (the
        # inference loop divides by 255) — confirm against its definition.
        stretched = pool.map(stretch_8bit, resized)
    date1 = np.stack(stretched[:num_bands]).astype(np.float32)
    date2 = np.stack(stretched[num_bands:]).astype(np.float32)
    return date1, date2

In [6]:
# Read both dates in one parallel pass. The combined list is expected to hold the 13 date-1
# bands followed by the 13 date-2 bands; stack_bands splits it at the same index.
N_BANDS_PER_DATE = 13
d1d2 = read_bands(d1_bands + d2_bands)

# Replace the date-2 half with its band-matched version (paired index-by-index with date 1).
d1d2[N_BANDS_PER_DATE:] = match_bands(d1d2[:N_BANDS_PER_DATE], d1d2[N_BANDS_PER_DATE:])

d1, d2 = stack_bands(d1d2)


In [7]:
# Change-mask canvas on the same (rows, cols) grid as d1. It only ever receives 0/1
# predictions and is written out as uint8, so allocate uint8 rather than the default
# float64 (8x less memory for a full scene).
out = np.zeros((d1.shape[1], d1.shape[2]), dtype=np.uint8)

In [8]:
# Tile the scene into input_size x input_size patches, run the change-detection model on
# batches of (date-1, date-2) patch pairs, and paste the thresholded masks into `out`.
input_size = 64   # patch side length; presumably the training size (weights file says inp64)
batch_size = 120  # patch pairs per forward pass
height, width = d1.shape[1], d1.shape[2]


def _predict_patches(patches1, patches2, origins):
    """Run the model on one batch of patch pairs and write the binary masks into `out`.

    patches1, patches2 : lists of (bands, input_size, input_size) slices of d1 / d2.
    origins : list of [row, col] top-left corners, one per patch pair.
    """
    inp1 = w(torch.from_numpy(np.asarray(patches1) / 255.))
    inp2 = w(torch.from_numpy(np.asarray(patches2) / 255.))
    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        logits = model(inp1, inp2)
    pred = (torch.sigmoid(logits) > 0.5).cpu().numpy()
    for (row, col), mask in zip(origins, pred):
        out[row:row + input_size, col:col + input_size] = mask


batches1 = []
batches2 = []
ijs = []
for i in range(0, height, input_size):
    for j in range(0, width, input_size):
        # Patches that would overrun the bottom/right edge are shifted back so they end
        # exactly at the edge (overlapping their neighbour) instead of being truncated.
        i0 = min(i, height - input_size)
        j0 = min(j, width - input_size)
        batches1.append(d1[:, i0:i0 + input_size, j0:j0 + input_size])
        batches2.append(d2[:, i0:i0 + input_size, j0:j0 + input_size])
        ijs.append([i0, j0])

        if len(batches1) == batch_size:
            _predict_patches(batches1, batches2, ijs)
            batches1, batches2, ijs = [], [], []

# Flush the final partial batch. Unless the patch count is an exact multiple of batch_size,
# these last patches (bottom-right of the scene) would otherwise never be predicted and
# would stay 0 in `out`.
if batches1:
    _predict_patches(batches1, batches2, ijs)
    batches1, batches2, ijs = [], [], []



In [19]:
# Write the change mask as a single-band uint8 GeoTIFF that reuses the georeferencing of
# d1_bands[1] (the printed profile shows a 10980 x 10980, single-band grid).
with rasterio.open(d1_bands[1]) as ref:
    profile = ref.profile  # a copy; stays valid after the dataset is closed
profile['dtype'] = 'uint8'
profile['driver'] = 'GTiff'
# The `with` block guarantees the file is flushed and closed even if the write fails.
with rasterio.open('../../../Yiwu/cd_out/T50RQS_20151126T024032_20170228T023631.tif', 'w', **profile) as fout:
    fout.write(out.astype(np.uint8), 1)

In [18]:
# Display the output raster profile (size, transform, CRS, dtype).
# NOTE(review): this cell's execution count is lower than the writing cell's, yet its output
# already shows dtype 'uint8' and driver 'GTiff' — it reflects an earlier run; re-run in order.
profile

{'width': 10980, 'transform': Affine(10.0, 0.0, 699960.0,
       0.0, -10.0, 3200040.0), 'blockxsize': 1024, 'count': 1, 'height': 10980, 'driver': 'GTiff', 'crs': CRS({'init': 'epsg:32650'}), 'tiled': True, 'dtype': 'uint8', 'blockysize': 1024, 'nodata': None}

In [None]:
# Reopen the written change map, presumably to inspect it.
# NOTE(review): the dataset is not closed in the visible code — close it (or use a `with`
# block) once done.
fin = rasterio.open('../../../Yiwu/cd_out/T50RQS_20151126T024032_20170228T023631.tif')
print 