In [1]:
import rasterio
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd

from rasterio.plot import adjust_band
from rasterio.plot import show
import rasterio.features
import rasterio.warp
import rasterio.mask

In [2]:
def _write_augmented_tile(out_path, tile, crs, transform):
    """Write one (band, row, col) array to ``out_path`` as a GeoTIFF."""
    with rasterio.open(
        out_path,
        'w',
        driver='GTiff',
        height=tile.shape[1],
        width=tile.shape[2],
        count=tile.shape[0],
        dtype=tile.dtype,
        crs=crs,
        nodata=None,
        transform=transform,
    ) as dst:
        dst.write(tile)


def offline_data_augmentation(data_dir_list):
    """Create flipped copies of every image for offline data augmentation.

    For each sub-directory ``dataset`` of the module-level ``data_path`` listed
    in ``data_dir_list``, every file is opened with rasterio and three variants
    are saved next to it as ``<dataset><i>_<suffix>.tif``, where ``i`` is the
    file's position in ``os.listdir``:

    * ``_B`` - ``np.fliplr`` of the tile: axis 1 (rows) reversed.
    * ``_C`` - ``np.flip(tile, -1)``: last axis (columns) reversed.
    * ``_D`` - both of the above applied.

    The band axis (axis 0) is never reversed. Each variant is written with the
    source file's CRS and its unmodified affine transform.

    Reference: https://note.nkmk.me/en/python-numpy-flip-flipud-fliplr/

    Parameters
    ----------
    data_dir_list : iterable of str
        Names of the class sub-directories under ``data_path``.

    Returns
    -------
    None
        The function only writes files.
    """
    for dataset in data_dir_list:
        dataset_dir = os.path.join(data_path, dataset)
        # Snapshot the listing first so files written below are not re-processed
        # within this run.
        img_list = os.listdir(dataset_dir)
        print ('Augmenting images of the dataset -'+'{}\n'.format(dataset))
        for i, img_name in enumerate(img_list):
            img_filename = os.path.join(dataset_dir, img_name)
            with rasterio.open(img_filename) as ds:
                tileA = ds.read()
                # Capture georeferencing while the dataset is still open; the
                # original read these attributes after the file was closed.
                crs = ds.crs
                transform = ds.transform

            mirrored_cols = np.flip(tileA, -1)
            variants = (
                ('B', np.fliplr(tileA)),
                ('C', mirrored_cols),
                ('D', np.fliplr(mirrored_cols)),
            )
            for suffix, tile in variants:
                out_name = '{}{}_{}.tif'.format(dataset, i, suffix)
                # os.path.join instead of a hard-coded '\\' keeps this portable.
                _write_augmented_tile(
                    os.path.join(dataset_dir, out_name), tile, crs, transform
                )


In [3]:
def load_samples(data_dir_list,fulldataset_df,csv_location):
    """Index every readable image under ``data_path`` and save the index as CSV.

    Each file in each class directory is opened and fully read with rasterio;
    files that raise are reported as corrupted and left out. One row
    (``FileName``, ``Label``, ``ClassName``) is recorded per readable file,
    using the module-level ``labels_name`` mapping for the label.

    Parameters
    ----------
    data_dir_list : iterable of str
        Names of the class sub-directories under ``data_path``.
    fulldataset_df : pandas.DataFrame
        Existing rows to extend. It is not modified in place.
    csv_location : str
        Path the combined table is written to.

    Returns
    -------
    pandas.DataFrame
        ``fulldataset_df`` followed by the newly indexed rows.
    """
    record_columns = ['FileName', 'Label', 'ClassName']
    records = []
    num_corrupted_files = 0

    for dataset in data_dir_list:
        img_list = os.listdir(os.path.join(data_path,dataset))
        print ('Loading the images of the dataset -'+'{}\n'.format(dataset))
        label = labels_name[dataset]

        for img_name in img_list:
            img_filename = os.path.join(data_path,dataset,img_name)
            try:
                # Reading the pixels, not just opening, is what exposes a
                # truncated or corrupted file.
                with rasterio.open(img_filename) as ds:
                    ds.read()
            except Exception:
                # Exception, not a bare except, so Ctrl-C still interrupts.
                print('{} is corrupted\n'.format(img_filename))
                num_corrupted_files += 1
                continue
            records.append({'FileName':img_filename,'Label':label,'ClassName':dataset})

    # Build the frame once; DataFrame.append was removed in pandas 2.0 and
    # growing a frame row by row is quadratic.
    new_rows = pd.DataFrame(records, columns=record_columns)
    if fulldataset_df.empty:
        # Keep any columns the caller declared, then add the record columns.
        all_columns = list(fulldataset_df.columns) + [
            col for col in record_columns if col not in fulldataset_df.columns
        ]
        fulldataset_df = new_rows.reindex(columns=all_columns)
    else:
        fulldataset_df = pd.concat([fulldataset_df, new_rows], ignore_index=True)

    print('tiles added:',len(records))
    print('corrupted files skipped:',num_corrupted_files)
    fulldataset_df.to_csv(csv_location)
    print('csv files are saved:{}'.format(csv_location))
    return fulldataset_df

In [4]:
#https://medium.com/@anuj_shah/creating-custom-data-generator-for-training-deep-learning-models-part-2-be9ad08f3f0e
#https://machinelearningmastery.com/how-to-load-large-datasets-from-directories-for-deep-learning-with-keras/
#http://www.jessicayung.com/using-generators-in-python-to-train-machine-learning-models/
#declare path directories
# Raw string: the backslashes are Windows path separators, not escape sequences
# (the value is unchanged; this only avoids the invalid-escape warning).
data_path = r'C:\Data\BinFullDroneDataset'
data_dir_list = ['noprospect', 'prospect']
print ('the data list is: ',data_dir_list)
csv_location = 'BinFullDrone_dataset.csv'

#define classes
# NOTE(review): labels_name defines 2 classes but num_classes is 5 - confirm
# which value later cells expect.
num_classes = 5
labels_name={'noprospect':0,'prospect':1}

#declare all needed dataframes
fulldataset_df = pd.DataFrame(columns=['FileName','Label','ClassName']) 

#run the data preparation scripts
#offline_data_augmentation(data_dir_list)
# The index is written to csv_location; the frame bound to fulldataset_df here
# is not modified in place by this call.
load_samples(data_dir_list,fulldataset_df,csv_location)



the data list is:  ['noprospect', 'prospect']
Loading the images of the dataset -noprospect

Loading the images of the dataset -prospect

tiles added: 15552
csv files are saved:{} BinFullDrone_dataset.csv


#check that an image tile and its three augmented variants open and display
# Raw strings keep the Windows backslashes literal (same values as before).
tile_paths = [
    r'C:\Data\Dataset\donga\donga_9.tif',
    r'C:\Data\Dataset\donga\donga9_B.tif',
    r'C:\Data\Dataset\donga\donga9_C.tif',
    r'C:\Data\Dataset\donga\donga9_D.tif',
]

for img_filename in tile_paths:
    # The context manager closes each dataset; the original left all four open.
    with rasterio.open(img_filename) as tile:
        plt.imshow(tile.read(1), cmap='pink')
        plt.show()

        print(tile.shape)
