In [None]:
!pip install Pillow

In [None]:
!apt-get update
!apt-get install gdal-bin python-gdal python3-gdal -y
!pip install GDAL

In [None]:
import matplotlib.pyplot as plt

import matplotlib.image as mpimg
import matplotlib.gridspec as gridspec
try:
    # Recent GDAL releases only ship the bindings inside the `osgeo` package.
    from osgeo import gdal
except ImportError:
    # Fallback for very old GDAL installs that still provide the top-level alias.
    import gdal
import osgeo.gdalnumeric as gdn
import numpy as np

# Directory holding the outlabels_* / outpatches_* rasters read by process().
patches_path = '/host_pwd/work/srgan/patches/'
# Number of patch pairs (figure rows) shown per page by plot2().
nrows = 2

# Patch IDs flagged as bad by clicking in the viewer. plot2() rewrites
# bad_ids.txt from this set on every click, so seed it from an existing file:
# otherwise re-running this cell (or restarting the kernel) would discard the
# IDs flagged in an earlier session as soon as the next click saves the set.
try:
    with open('bad_ids.txt') as _bad_ids_file:
        bad_ids = {line.strip() for line in _bad_ids_file if line.strip()}
except FileNotFoundError:
    bad_ids = set()

def plot2(date, ids_low_res, ids_high_res, patches_low_res, patches_high_res,
          lr_size=64, hr_size=256):
    """Interactively review matching low-/high-resolution patch pairs.

    Opens a figure with ``nrows`` rows and two columns. The left column shows
    a low-resolution ('S2') patch and the right column the high-resolution
    ('DOF250') patch with the same ID. Each panel's y-label is the patch ID.

    Interaction:
      * scrolling with a positive step shows the next page of ``nrows`` pairs,
        a negative step the previous page;
      * clicking a panel toggles its patch ID in the module-level ``bad_ids``
        set and rewrites ``bad_ids.txt`` with the current set, one ID per line.

    Parameters
    ----------
    date : str
        Label prefixed to the figure title.
    ids_low_res, ids_high_res : sequence
        Patch IDs in the order their patches are stacked in
        ``patches_low_res`` / ``patches_high_res``.
    patches_low_res, patches_high_res : numpy.ndarray
        Patches stacked vertically: patch ``i`` occupies rows
        ``i*size:(i+1)*size`` and columns ``0:size`` (all channels).
    lr_size, hr_size : int, optional
        Edge length in pixels of a low-/high-resolution patch.

    Returns
    -------
    matplotlib.figure.Figure
        The viewer figure; its callbacks stay connected while it exists.
    """
    import matplotlib as mpl

    my_dpi = mpl.rcParams['figure.dpi']
    side = (nrows * 400) / my_dpi
    fig = plt.figure(constrained_layout=True, figsize=(side, side), dpi=my_dpi)
    spec = gridspec.GridSpec(ncols=2, nrows=nrows, figure=fig)

    axes = []
    for row in range(nrows):
        axrow = [fig.add_subplot(spec[row, col]) for col in range(2)]
        for ax in axrow:
            # Hide tick labels but keep the y-label, which carries the patch ID.
            ax.tick_params(labelleft=False, labelbottom=False)
        if row == 0:
            axrow[0].set_title('S2')
            axrow[1].set_title('DOF250')
        axes.append(axrow)

    # First position of every high-res ID: same answer as list.index, but an
    # O(1) lookup, and it also works when ids_high_res is not a list.
    hr_position = {}
    for pos, patch_id in enumerate(ids_high_res):
        hr_position.setdefault(patch_id, pos)

    n_total = len(ids_low_res)
    page_start = 0   # index into ids_low_res of the pair shown in row 0
    page_title = ''  # title text produced by the most recent render()

    def clear_panel(ax):
        """Remove the image and ID label drawn by a previous render()."""
        for image in list(ax.images):
            image.remove()
        ax.set_ylabel('')

    def render():
        """Draw the page starting at ``page_start`` and refresh the title."""
        nonlocal page_title
        lines = [f"{date}: {page_start + 1} / {n_total}" if n_total else f"{date}:"]
        for row, (ax_lr, ax_hr) in enumerate(axes):
            # Clear first so panels past the end or without a pair never keep
            # a stale image/ID that a click would then flag by mistake.
            clear_panel(ax_lr)
            clear_panel(ax_hr)
            lr_idx = page_start + row
            if lr_idx >= n_total:
                continue
            patch_id = ids_low_res[lr_idx]
            hr_idx = hr_position.get(patch_id)
            if hr_idx is None:
                # "Ni para za ID" = "no pair for ID": no high-res patch has this ID.
                lines.append(f"Ni para za ID {patch_id}")
                continue
            ax_lr.imshow(patches_low_res[lr_idx * lr_size:(lr_idx + 1) * lr_size, 0:lr_size, :])
            ax_lr.set_ylabel(str(patch_id))
            ax_hr.imshow(patches_high_res[hr_idx * hr_size:(hr_idx + 1) * hr_size, 0:hr_size, :])
            ax_hr.set_ylabel(str(patch_id))
        if page_start + nrows >= n_total:
            lines.append("You have reached the end of array.")
        page_title = "\n".join(lines)
        fig.suptitle(page_title, fontsize=13)
        fig.canvas.draw_idle()

    def onclick(event):
        """Toggle the clicked panel's patch ID in bad_ids and persist the set."""
        ax = event.inaxes
        if ax is None:
            return  # click outside any panel (title, margins)
        patch_id = ax.get_ylabel()
        if not patch_id:
            return  # empty panel: nothing to toggle
        if patch_id in bad_ids:
            bad_ids.remove(patch_id)
            status = f"{patch_id} removed from bad_ids."
        else:
            bad_ids.add(patch_id)
            status = f"{patch_id} added to bad_ids."

        # Sorted so the file content is deterministic; key=str tolerates mixed types.
        with open('bad_ids.txt', 'w') as f:
            for item in sorted(bad_ids, key=str):
                f.write("%s\n" % item)

        # Rebuild from the page title so status messages do not pile up.
        fig.suptitle(page_title + "\n" + status + " bad_ids saved to bad_ids.txt", fontsize=13)
        fig.canvas.draw_idle()

    def onscroll(event):
        """Page forward (positive step) or backward (negative step) by ``nrows`` pairs."""
        nonlocal page_start
        if event.step > 0 and page_start + nrows < n_total:
            page_start += nrows
        elif event.step < 0 and page_start > 0:
            page_start = max(0, page_start - nrows)
        else:
            return  # already at the first/last page, or a zero step
        render()

    fig.canvas.mpl_connect('button_press_event', onclick)
    fig.canvas.mpl_connect('scroll_event', onscroll)
    render()
    return fig
    

def img_to_array(input_file, dim_ordering="channels_last", dtype='uint32'):
    """Read every band of a raster file into one numpy array.

    Based on https://gis.stackexchange.com/questions/32995/fully-load-raster-into-a-numpy-array/33070

    Parameters
    ----------
    input_file : str
        Path of a raster file GDAL can open.
    dim_ordering : str
        ``"channels_last"`` returns shape ``(rows, cols, bands)``; any other
        value keeps the bands on axis 0, i.e. ``(bands, rows, cols)``.
    dtype : str or numpy dtype
        Element type of the returned array.

    Returns
    -------
    numpy.ndarray

    Raises
    ------
    IOError
        If GDAL cannot open ``input_file``.
    """
    dataset = gdal.Open(input_file, gdal.GA_ReadOnly)
    if dataset is None:
        # Unless gdal.UseExceptions() is active, gdal.Open reports failure by returning None.
        raise IOError(f"GDAL could not open raster file: {input_file}")
    try:
        bands = [gdn.BandReadAsArray(dataset.GetRasterBand(i))
                 for i in range(1, dataset.RasterCount + 1)]
    finally:
        # Dropping the last reference closes the GDAL dataset and its file handle.
        dataset = None

    # One copy from stack; astype(copy=False) skips a second copy when dtype already matches.
    arr = np.stack(bands).astype(dtype, copy=False)
    if dim_ordering == "channels_last":
        arr = np.transpose(arr, [1, 2, 0])  # (bands, rows, cols) -> (rows, cols, bands)
    return arr

def process(date):
    lr = img_to_array(f"{patches_path}outlabels_{date}_64x64.tif",dtype='uint32')
    ids_low_res = [el[0][0] for el in lr]
    
    hr = img_to_array(f"{patches_path}outlabels_{date}_256x256.tif",dtype='uint32')
    ids_high_res = [el[0][0] for el in hr]
    
    patches_low_res = img_to_array(f"{patches_path}outpatches_{date}_64x64.tif")
    patches_high_res = img_to_array(f"{patches_path}outpatches_{date}_256x256.tif")
    
    %matplotlib notebook
    
    plot2(date, ids_low_res,ids_high_res,patches_low_res, patches_high_res)
    

In [None]:
# Open the patch-pair viewer for date tag '13p4'.
process(date='13p4')

In [None]:
# Open the patch-pair viewer for date tag '23p5'.
process(date='23p5')

In [None]:
# Open the patch-pair viewer for date tag '26p4'.
process(date='26p4')