In [1]:
import random
import numpy as np
import rasterio as rio
import os, shutil, glob, json
import cv2

from PIL import Image
import torch.utils.data

from shapely.geometry import Polygon
from skimage.draw import polygon

from phobos.grain import Grain
from phobos.io import getDataLoaders
from phobos.transforms import Normalize

In [2]:
def get_image_meta(dsetdir):
    cogpath = os.path.join(dsetdir,'tiles')
    tifpath = os.path.join(dsetdir,'images') 
    annpath = os.path.join(dsetdir,'annotations') 

    meta = {}

    for apath in glob.glob(f'{annpath}/*.json'):
        key = apath.split('/')[-1].split('.')[0]
        
        tpath = os.path.join(tifpath,key)
        cpath = os.path.join(cogpath,key)
        
        meta[key] = {}

        meta[key]['ann'] = apath
        meta[key]['cog'] = cpath
        meta[key]['tif'] = tpath

    return meta

In [3]:
def get_full_load(meta):
    dates =  {
        1: 'imgs_1',
        2: 'imgs_mid_1',
        3: 'imgs_mid_2',
        4: 'imgs_mid_1',
        5: 'imgs_2'
    }

    load = {}

    for key in meta:
        load[key] = {}

        afp  = open(meta[key]['ann'],'r')
        amap = json.load(afp)

        jpath = glob.glob(f"{meta[key]['cog']}/*.jp2")[0]
        jrstr = rio.open(jpath)
        
        jparr = jrstr.read()
        jtrns = jrstr.transform 

        h,w  = jparr.shape[1:]
        mask = np.zeros((h,w),dtype=np.uint8)

        for annotation in amap['annotations']:
            poly = Polygon(annotation['geometry']['coordinates'][0])
            
            xs,ys = poly.exterior.coords.xy
            rc    = rio.transform.rowcol(jtrns,xs,ys)
            poly  = np.asarray(list(zip(rc[0],rc[1])))
            rr,cc = polygon(poly[:,0],poly[:,1],mask.shape)
            mask[rr,cc] = 1

        mask = mask.transpose().astype(np.uint8)
        print(f'mask shape : {mask.shape}')
        
        imap = {}
        for dkey in dates:
            dpath = os.path.join(meta[key]['tif'],dates[dkey])

            imgnpath = glob.glob(f'{dpath}/*B08.tif')[0]
            imgrpath = glob.glob(f'{dpath}/*B04.tif')[0]
            imggpath = glob.glob(f'{dpath}/*B03.tif')[0]
            imgbpath = glob.glob(f'{dpath}/*B02.tif')[0]

            rrstr = rio.open(imgrpath)
            grstr = rio.open(imggpath)
            brstr = rio.open(imgbpath)
            nrstr = rio.open(imgnpath)

            rarr = cv2.resize(rrstr.read()[0].astype(np.float32),(h,w))
            garr = cv2.resize(grstr.read()[0].astype(np.float32),(h,w))
            barr = cv2.resize(brstr.read()[0].astype(np.float32),(h,w))
            narr = cv2.resize(nrstr.read()[0].astype(np.float32),(h,w))

            tile = np.stack([rarr,garr,barr,narr],axis=0).squeeze()
            print(f'tile shape : {tile.shape}')
            
            imap[dkey] = tile
        
        load[key] = {'imgs': imap, 'mask': mask}
        print(f'key : {key}, h : {h}, w : {w}')

    return load

In [4]:
def get_train_val_keys(meta, full_load, shape, args):
    p = shape
    s = args.stride
    r = args.ratio
    
    th = args.thres

    keys = list(meta.keys())
    random.shuffle(keys)

    tkeys = keys[:int(r*len(keys))]
    vkeys = keys[int(r*len(keys)):]

    tkeylist,vkeylist = [],[]

    for key in keys:
        mask = full_load[key]['mask']
        
        h,w = mask.shape
        keylist = [[key,i,j] 
                    for i in range(0,h,s) \
                        for j in range(0,w,s) \
                            if i+p<h and j+p<w \
                                and np.sum(mask[i:i+p,j:j+p]) > th]
        
        if key in tkeys:
            tkeylist.extend(keylist)
        elif key in vkeys:
            vkeylist.extend(keylist)

    return tkeylist,vkeylist

In [5]:
def get_sample(full_load,key,x,y,shape,args):
    s = shape
    load = full_load[key]
    ibands = args.input.heads['inp1']['bands']

    means = [ibands[band]['mean'] for band in ibands]
    stds  = [ibands[band]['std'] for band in ibands]

    N = Normalize(mean=means,std=stds)

    ipatchlist = []
    for date in load['imgs']:
        tile  = load['imgs'][date]
        print('tile shape',tile.shape)
        patch = tile[:,x:x+s,y:y+s].astype(np.float32)#.transpose((1,2,0))

        #patch = N.apply(patch).transpose((2,0,1))

        ipatchlist.append(patch)

    for patch in ipatchlist:
        print(f'key: {key}, x: {x}, y: {y}, shape: {patch.shape}')
    
    ipatch = np.stack(ipatchlist,axis=0)

    mask   = load['mask']
    print(f'mask shape: {mask.shape}')
    
    mpatch = mask[x:x+s,y:y+s]

    inputs = { 'inp1': ipatch }
    labels = { 'out1': mpatch }

    return inputs, labels
        

In [6]:
class OSCDDataset(torch.utils.data.Dataset):
    def __init__(self,samples,full_load,args,shape):
        random.shuffle(samples)

        self.args = args
        self.shape = shape
        self.loader = get_sample
        self.samples = samples
        self.full_load = full_load

        
    
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index) :
        key,x,y = self.samples[index]

        return get_sample(
                        full_load=self.full_load,
                        key=key,x=x,y=y,
                        shape=self.shape,
                        args=self.args
                )


            

In [7]:
args = Grain(yaml='../metadata.yaml',polyaxon_exp=None)
#args.parse_args_from_yaml('../metadata.yaml')

shape = args.input.heads['inp1'].shape.H


oscd_meta = get_image_meta(args.dataset_path)
oscd_full_load = get_full_load(oscd_meta)

tkeys, vkeys = get_train_val_keys(oscd_meta,oscd_full_load,shape, args)

print(f'number of train keys : {len(tkeys)}')
print(f'number of val keys : {len(vkeys)}\n')

mask shape : (540, 695)
tile shape : (4, 540, 695)
tile shape : (4, 540, 695)
tile shape : (4, 540, 695)
tile shape : (4, 540, 695)
tile shape : (4, 540, 695)
key : hongkong, h : 695, w : 540
mask shape : (679, 631)
tile shape : (4, 679, 631)
tile shape : (4, 679, 631)
tile shape : (4, 679, 631)
tile shape : (4, 679, 631)
tile shape : (4, 679, 631)
key : saclay_e, h : 631, w : 679
mask shape : (563, 339)
tile shape : (4, 563, 339)
tile shape : (4, 563, 339)
tile shape : (4, 563, 339)
tile shape : (4, 563, 339)
tile shape : (4, 563, 339)
key : rennes, h : 339, w : 563
mask shape : (461, 517)
tile shape : (4, 461, 517)
tile shape : (4, 461, 517)
tile shape : (4, 461, 517)
tile shape : (4, 461, 517)
tile shape : (4, 461, 517)
key : bordeaux, h : 517, w : 461
mask shape : (716, 824)
tile shape : (4, 716, 824)
tile shape : (4, 716, 824)
tile shape : (4, 716, 824)
tile shape : (4, 716, 824)
tile shape : (4, 716, 824)
key : lasvegas, h : 824, w : 716
mask shape : (476, 458)
tile shape : (4, 4

In [8]:
datasets = {
    'train': OSCDDataset(
                    samples=tkeys,
                    full_load=oscd_full_load,
                    args=args, shape=shape
             ),
    'val': OSCDDataset(
                    samples=vkeys,
                    full_load=oscd_full_load,
                    args=args, shape=shape
           )
}

loaders = getDataLoaders(
    datasets=datasets,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    distributed=args.distributed,
    load=args.load
)

In [9]:
inputs, labels = get_sample(oscd_full_load,'abudhabi',0,0,shape,args)

print(inputs['inp1'].shape)
print(labels['out1'].shape)

tile shape (4, 785, 799)
tile shape (4, 785, 799)
tile shape (4, 785, 799)
tile shape (4, 785, 799)
tile shape (4, 785, 799)
key: abudhabi, x: 0, y: 0, shape: (4, 64, 64)
key: abudhabi, x: 0, y: 0, shape: (4, 64, 64)
key: abudhabi, x: 0, y: 0, shape: (4, 64, 64)
key: abudhabi, x: 0, y: 0, shape: (4, 64, 64)
key: abudhabi, x: 0, y: 0, shape: (4, 64, 64)
mask shape: (785, 799)
(5, 4, 64, 64)
(64, 64)


In [10]:
train_dataset = datasets['train']
for inputs, labels in train_dataset:
    print(inputs['inp1'].shape)
    print(labels['out1'].shape)
    print()
    

tile shape (4, 788, 1015)
tile shape (4, 788, 1015)
tile shape (4, 788, 1015)
tile shape (4, 788, 1015)
tile shape (4, 788, 1015)
key: cupertino, x: 272, y: 368, shape: (4, 64, 64)
key: cupertino, x: 272, y: 368, shape: (4, 64, 64)
key: cupertino, x: 272, y: 368, shape: (4, 64, 64)
key: cupertino, x: 272, y: 368, shape: (4, 64, 64)
key: cupertino, x: 272, y: 368, shape: (4, 64, 64)
mask shape: (788, 1015)
(5, 4, 64, 64)
(64, 64)

tile shape (4, 540, 695)
tile shape (4, 540, 695)
tile shape (4, 540, 695)
tile shape (4, 540, 695)
tile shape (4, 540, 695)
key: hongkong, x: 464, y: 464, shape: (4, 64, 64)
key: hongkong, x: 464, y: 464, shape: (4, 64, 64)
key: hongkong, x: 464, y: 464, shape: (4, 64, 64)
key: hongkong, x: 464, y: 464, shape: (4, 64, 64)
key: hongkong, x: 464, y: 464, shape: (4, 64, 64)
mask shape: (540, 695)
(5, 4, 64, 64)
(64, 64)

tile shape (4, 563, 339)
tile shape (4, 563, 339)
tile shape (4, 563, 339)
tile shape (4, 563, 339)
tile shape (4, 563, 339)
key: rennes, x: 35

In [10]:
train_loader = loaders['train']
val_loader = loaders['val']

print('train')
for inputs, labels in train_loader:
    print(inputs['inp1'].shape)
    print(labels['out1'].shape)
    #break

print('val')
for inputs, labels in val_loader:
    print(inputs['inp1'].shape)
    print(labels['out1'].shape)



train
key: aguasclaras, x: 64, y: 128, shape: (4, 64, 64)key: abudhabi, x: 240, y: 400, shape: (4, 64, 64)key: lasvegas, x: 368, y: 528, shape: (4, 64, 64)key: beirut, x: 432, y: 944, shape: (4, 64, 64)



key: aguasclaras, x: 64, y: 128, shape: (4, 64, 64)key: abudhabi, x: 240, y: 400, shape: (4, 64, 64)key: beirut, x: 432, y: 944, shape: (4, 64, 64)
key: lasvegas, x: 368, y: 528, shape: (4, 64, 64)
key: aguasclaras, x: 64, y: 128, shape: (4, 64, 64)

key: abudhabi, x: 240, y: 400, shape: (4, 64, 64)
key: beirut, x: 432, y: 944, shape: (4, 64, 64)key: lasvegas, x: 368, y: 528, shape: (4, 64, 64)
key: aguasclaras, x: 64, y: 128, shape: (4, 64, 64)

key: abudhabi, x: 240, y: 400, shape: (4, 64, 64)
key: beirut, x: 432, y: 944, shape: (4, 64, 64)key: lasvegas, x: 368, y: 528, shape: (4, 64, 64)key: aguasclaras, x: 64, y: 128, shape: (4, 64, 64)


key: beirut, x: 432, y: 944, shape: (4, 64, 64)


key: lasvegas, x: 368, y: 528, shape: (4, 64, 64)


key: abudhabi, x: 240, y: 400, shape: (4,