In [17]:
import pickle
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches
import numpy as np

import aug_util as aug
import wv_util as wv
from PIL import Image
import csv
import tqdm
import itertools
import glob
%matplotlib inline

# Root of the xView data on this machine; chipped images/labels are written
# under chip_dir (passed as c_dir to exportChipImages below).
fdir = '/data/zjc4/'
chip_dir = fdir + "chipped/data/"

In [4]:
# Load the class number -> class string label map.
# Each line looks like "<id>:<name>"; only the first CSV field of a line is used.
labels = {}
with open('xview_class_labels.txt') as f:
    for row in csv.reader(f):
        if not row:
            # csv.reader yields [] for blank lines (e.g. a trailing empty line).
            continue
        # maxsplit=1 keeps any ':' that is part of the label name itself;
        # a line with no ':' fails loudly on the unpack instead of being misread.
        class_id, class_name = row[0].split(":", 1)
        labels[int(class_id)] = class_name

In [2]:
#Loading our labels
# Per-box coordinates, the image file name each box belongs to, and its class id.
coords, chips, classes = wv.get_labels("{}xView_train.geojson".format(fdir))

In [2]:

# glob.glob returns paths in filesystem-dependent order; sort for reproducible runs.
all_images = sorted(glob.glob(fdir+'train_images/*.tif'))

In [5]:
# Group i lists the original xView class ids that are merged into new class i.
grouped_classes = [[11,12],[13],[17,18,20,21],\
                   [19,23,24,25,28,29,60,61,65,26],[41,42,50,40,44,45,47,49]]
def filterClasses(chip_coords,chip_classes,grouped_classes):
    """Keep only boxes whose class is listed in ``grouped_classes`` and relabel by group.

    Parameters
    ----------
    chip_coords : ndarray
        Per-box coordinates; rows are selected with the same mask as ``chip_classes``.
    chip_classes : ndarray of int
        Original class id of each box. The caller's array is not modified.
    grouped_classes : list of lists of int
        Group ``i`` lists the original ids that become new class ``i``.

    Returns
    -------
    (coords, classes) for the kept boxes, where ``classes`` holds group indices.
    If an id appears in more than one group, the first group wins.
    """
    wanted = list(itertools.chain.from_iterable(grouped_classes))
    keep = np.isin(chip_classes, wanted)
    kept_coords, original_ids = chip_coords[keep], chip_classes[keep]

    # Test every group against the ORIGINAL ids and only assign boxes that are
    # still unassigned. Remapping in place and re-testing the remapped array
    # would let a freshly assigned index (0..k) match an original id listed in
    # a later group and get remapped a second time.
    new_ids = original_ids.copy()
    assigned = np.zeros(original_ids.shape, dtype=bool)
    for idx, g_cls in enumerate(grouped_classes):
        mask = np.isin(original_ids, g_cls) & ~assigned
        new_ids[mask] = idx
        assigned |= mask
    return kept_coords, new_ids

def plotDarknetFmt(c_img,x_center,y_center,ws,hs,c_cls,szx,szy):
    """Display ``c_img`` with normalized center/size boxes overlaid in red.

    ``x_center`` and ``ws`` are scaled by ``szx``; ``y_center`` and ``hs`` are
    scaled by ``szy`` to obtain pixel units. One rectangle is drawn per entry
    of ``c_cls`` (its length sets the box count).
    """
    _, ax = plt.subplots(1, figsize=(10, 10))
    ax.imshow(c_img)
    n_boxes = c_cls.shape[0]
    for i in range(n_boxes):
        box_w = ws[i] * szx
        box_h = hs[i] * szy
        # Rectangle is anchored at its (min x, min y) corner: step back half a box from the center.
        anchor = (x_center[i] * szx - box_w / 2, y_center[i] * szy - box_h / 2)
        outline = patches.Rectangle(anchor, box_w, box_h,
                                    linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(outline)
    plt.show()

def toDarknetFmt(c_box,c_cls,c_img,debug=False):
    """Convert pixel-space corner boxes of one chip into Darknet label rows.

    Parameters
    ----------
    c_box : array-like, shape (N, 4)
        Boxes as (xmin, ymin, xmax, ymax) in pixel coordinates of ``c_img``.
        Not modified.
    c_cls : array-like, shape (N,)
        Class index of each box.
    c_img : ndarray, shape (H, W) or (H, W, C)
        The chip; only its height/width are used (and it is drawn when ``debug``).
    debug : bool
        If True, plot the converted boxes over the chip.

    Returns
    -------
    ndarray, shape (N, 5)
        Rows of (class, x_center, y_center, width, height). x values are
        divided by the chip width and y values by the chip height.
    """
    # numpy images are (rows, cols, ...) == (height, width, ...). x runs along the
    # width and y along the height, so they must be normalized by W and H respectively.
    img_h, img_w = c_img.shape[:2]
    # Work on a float view rather than writing back into c_box: avoids mutating the
    # caller's array and avoids integer truncation of the normalized values.
    boxes = np.asarray(c_box, dtype=float)
    xmin, xmax = boxes[:, 0] / img_w, boxes[:, 2] / img_w
    ymin, ymax = boxes[:, 1] / img_h, boxes[:, 3] / img_h
    ws, hs = (xmax - xmin), (ymax - ymin)
    x_center, y_center = xmin + (ws / 2), ymin + (hs / 2)
    # Visualize using mpl
    if debug:
        plotDarknetFmt(c_img, x_center, y_center, ws, hs, c_cls, img_w, img_h)
    result = np.vstack((c_cls, x_center, y_center, ws, hs))
    return result.T

def parseChip(c_img, c_box, c_cls,img_num,c_dir):
    """Save every chip of one source image plus its Darknet label file.

    For chip ``k`` (0 <= k < c_img.shape[0]) the boxes ``c_box[k]`` and classes
    ``c_cls[k]`` are converted with ``toDarknetFmt`` and written to
    ``<c_dir>labels/<stem>.txt``. The chip pixels are written to
    ``<c_dir>images/<stem>.jpg``, where stem is ``<img_num:06>_<k:02>``.
    This function does not create the ``labels``/``images`` directories.

    Returns the list of written image paths, in chip order.
    """
    written = []
    for chip_no in range(c_img.shape[0]):
        stem = "{:06}_{:02}".format(int(img_num), chip_no)
        label_path = "{}labels/{}.txt".format(c_dir, stem)
        image_path = "{}images/{}.jpg".format(c_dir, stem)
        chip_pixels = c_img[chip_no]
        # One "class x_center y_center width height" row per box.
        darknet_rows = toDarknetFmt(c_box[chip_no], c_cls[chip_no], chip_pixels)
        np.savetxt(label_path, darknet_rows, fmt='%i %1.6f %1.6f %1.6f %1.6f')
        Image.fromarray(chip_pixels).save(image_path)
        written.append(image_path)
    return written

def exportChipImages(image_paths,c_dir,set_str="train",img_dir=None,chip_shape=(600,600)):
    """Chip each listed image, save Darknet images/labels, and write the file list.

    For every name in ``image_paths`` the image is loaded from ``img_dir``
    (default ``fdir + 'train_images/'``). Its boxes are looked up in the
    module-level ``coords``/``chips``/``classes`` arrays by file name, then
    filtered and regrouped with ``filterClasses(..., grouped_classes)``. The
    image is chipped with ``wv.chip_image(..., shape=chip_shape)``, and chips
    plus labels are written under ``c_dir`` by ``parseChip``. Finally the sorted
    chip image paths are written to ``fdir + 'chipped/xview_img_<set_str>.txt'``
    and their count is printed.

    Images that cannot be found are skipped, and their count is printed. Any other
    error propagates, including a missing ``c_dir`` ``images/``/``labels/`` directory.
    """
    if img_dir is None:
        img_dir = fdir + 'train_images/'
    fnames = []
    missing = []
    for rel_path in image_paths:
        img_pth = img_dir + rel_path
        img_name = img_pth.split("/")[-1]
        img_num = img_name.split(".")[0]
        # Only the image load is best-effort: a split list may name images that
        # are absent locally. Failures while writing outputs must not be hidden.
        try:
            arr = wv.get_image(img_pth)
        except FileNotFoundError:
            missing.append(img_pth)
            continue

        in_image = chips == img_name
        chip_coords = coords[in_image]
        chip_classes = classes[in_image].astype(np.int64)

        chip_coords, chip_classes = \
            filterClasses(chip_coords, chip_classes, grouped_classes)

        c_img, c_box, c_cls = wv.chip_image(img=arr, coords=chip_coords,
                                            classes=chip_classes, shape=chip_shape)

        fnames.extend(parseChip(c_img, c_box, c_cls, img_num, c_dir))
    if missing:
        print("skipped {} missing image(s)".format(len(missing)))
    lines = sorted(fnames)
    print(len(lines))
    with open(fdir+"chipped/xview_img_{}.txt".format(set_str),
              mode='w', encoding='utf-8') as myfile:
        myfile.write('\n'.join(lines))

In [18]:
# Export chips for each split; each split's image list comes from "<split>_tifs.pkl".
string_sets = ["train","valid","test"]
data_sets = {}
for str_set in string_sets:
    # NOTE(review): pickle.load can execute arbitrary code. These split lists are
    # presumably produced by an earlier local step — only load trusted files.
    with open("{}_tifs.pkl".format(str_set),"rb") as f:
        data_sets[str_set] = pickle.load(f)
    exportChipImages(data_sets[str_set],chip_dir,set_str=str_set)

8889
3769
5689


In [19]:
# Completion marker so the end of the run is visible in the cell output.
print("done")

done
