In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import random
import shutil

In [2]:
# Get the path and the filename of the *.npz file generated by download.py
def path_of_data(father_dir, index):
    """Return the path and extension-less name of the index-th entry of a directory.

    Parameters
    ----------
    father_dir : str
        Directory holding the ``*.npz`` files generated by download.py.
    index : int
        Position of the wanted entry in the sorted directory listing.

    Returns
    -------
    (path, name) : tuple of str
        ``path`` is ``father_dir`` joined with the entry's filename;
        ``name`` is the filename with its extension removed.

    Raises
    ------
    IndexError
        If ``index`` is outside the directory listing.
    """
    # os.listdir returns entries in arbitrary order; sorting makes the
    # index -> file mapping reproducible between runs and platforms.
    childDirList = sorted(os.listdir(father_dir))
    filename = childDirList[index]
    # os.path.join picks the right separator for the current OS (the old
    # '{}\{}' format string only produced a valid path on Windows).
    path = os.path.join(father_dir, filename)
    # Strip the extension (e.g. '.npz') whatever its length.
    name = os.path.splitext(filename)[0]
    return path, name

In [3]:
# Load the *.npz file to a numpy array
def load_saved_data(path):
    """Load the image and trail position stored in a ``*.npz`` archive.

    Parameters
    ----------
    path : str
        Path of an archive containing the arrays ``matrix1`` and ``matrix2``.

    Returns
    -------
    (image_data, position) : tuple of numpy.ndarray
        ``image_data`` is the array stored under ``matrix1`` and
        ``position`` the array stored under ``matrix2``.

    Raises
    ------
    KeyError
        If either key is missing from the archive.
    """
    # The object returned by np.load for an .npz archive keeps the file open
    # until it is closed; the context manager releases the handle as soon as
    # both arrays have been read into memory.
    with np.load(path) as data:
        image_data = data['matrix1']
        position = data['matrix2']

    return image_data, position

In [4]:
# Calculate the center of the asteroid trail
def find_center(position):
    """Return the midpoint ``(x, y)`` of the segment described by ``position``.

    ``position`` unpacks into ``x1, y1, x2, y2``; the result is the pair of
    averages of the two x values and of the two y values.
    """
    x_start, y_start, x_end, y_end = position
    mid_x = (x_start + x_end)/2
    mid_y = (y_start + y_end)/2
    return mid_x, mid_y

In [5]:
# Calculate the possible boundary of the selected region
def find_boundary(image_upper_limit, center_value, random_slide, window_shape_value):
    """Return the allowed range for the window origin along one axis.

    The window origin may slide up to ``random_slide`` pixels either side of
    the position that centres the window on ``center_value``, but the window
    must stay inside ``[0, image_upper_limit]``.

    Parameters
    ----------
    image_upper_limit : number
        Size of the image along this axis.
    center_value : number
        Coordinate the window should be centred around.
    random_slide : number
        Maximum offset of the window from the centred position.
    window_shape_value : number
        Size of the window along this axis.

    Returns
    -------
    (lowest_origin, highest_origin) : tuple of numbers
        Inclusive range of valid window origins. When the image is smaller
        than the window both values are 0.
    """
    half_window = window_shape_value/2
    lower_bound = min(max(0, center_value - half_window - random_slide),
                      image_upper_limit - window_shape_value)
    # If the image is smaller than the window, image_upper_limit - window_shape_value
    # is negative; a negative origin would be read as "count from the end"
    # when slicing, so clamp it to 0.
    lower_bound = max(0, lower_bound)
    upper_bound = max(min(image_upper_limit, center_value + half_window + random_slide),
                      window_shape_value)
    # upper_bound is where the window may end; subtract its size to get the origin.
    return lower_bound, upper_bound - window_shape_value

In [6]:
# Randomly pick a region around the asteroid trail
def random_pick_xy(image_shape, center_position, random_slide, window_shape_value):
    """Draw a random window origin ``(x, y)`` near ``center_position``.

    ``image_shape[1]`` bounds the x coordinate and ``image_shape[0]`` bounds
    the y coordinate. For each axis the admissible range comes from
    ``find_boundary`` and a value is drawn with ``random.randint`` after
    truncating both ends of the range to int. x is drawn before y.
    """
    col_center, row_center = (int(value) for value in center_position)

    # Horizontal axis: limited by the number of columns.
    x_low, x_high = find_boundary(image_shape[1], col_center, random_slide, window_shape_value)
    x_final = random.randint(int(x_low), int(x_high))

    # Vertical axis: limited by the number of rows.
    y_low, y_high = find_boundary(image_shape[0], row_center, random_slide, window_shape_value)
    y_final = random.randint(int(y_low), int(y_high))

    return x_final, y_final

In [7]:
# Refresh the data and position. Return the selected data and the asteroid
# trail position in the new picture
def data_position_refresh(data, window_shape_value, position, window_position):
    """Crop a square window out of ``data`` and shift ``position`` into it.

    ``window_position`` is the ``(x, y)`` origin of the window; rows are
    indexed by y and columns by x. The returned position is the original
    ``x1, y1, x2, y2`` with the window origin subtracted.

    Returns
    -------
    (data_refresh, position_refresh)
        The cropped slice of ``data`` and the shifted 4-tuple.
    """
    col_origin, row_origin = window_position
    x1, y1, x2, y2 = position

    shifted_position = (
        x1 - col_origin,
        y1 - row_origin,
        x2 - col_origin,
        y2 - row_origin,
    )

    row_span = slice(row_origin, row_origin + window_shape_value)
    col_span = slice(col_origin, col_origin + window_shape_value)
    cropped = data[row_span, col_span]

    return cropped, shifted_position

In [8]:
# Plot the data with a red box around the asteroid trail
def plot_data(name, data, box_position, pic_dir):
    """Save a grayscale preview of ``data`` with a red rectangle drawn on it.

    Parameters
    ----------
    name : str
        Base name of the output file; ``.jpg`` is appended.
    data : 2-D array
        Image to display. Values are shown on a gray colormap clipped to
        the range ``[0, 0.01]``.
    box_position : sequence of 4 numbers
        ``x1, y1, x2, y2`` corners of the rectangle to draw.
    pic_dir : str
        Directory the picture is written to.
    """
    x1, y1, x2, y2 = box_position
    plt.imshow(data, cmap='gray', vmin=0, vmax=0.01)
    # Draw the four edges of the rectangle in red.
    plt.plot([x1, x2], [y1, y1], 'r')
    plt.plot([x1, x2], [y2, y2], 'r')
    plt.plot([x1, x1], [y1, y2], 'r')
    plt.plot([x2, x2], [y1, y2], 'r')
    plt.colorbar()
    # os.path.join keeps the output path valid on every OS (the previous
    # '{}\{}.jpg' format string hard-coded the Windows separator).
    plt.savefig(os.path.join(pic_dir, '{}.jpg'.format(name)), dpi=300)
    # Clear the current figure so the next call starts from an empty canvas.
    plt.clf()

In [9]:
# Save the data and asteroid position into a *.npz file
def save_data(name, index, data, position, data_dir):
    """Write ``data`` and ``position`` to ``<data_dir>/<name>.npz``.

    Parameters
    ----------
    name : str
        Base name of the archive (``np.savez`` appends ``.npz``).
    index : int
        Only used in the progress message.
    data : array-like
        Stored under the key ``matrix1``.
    position : array-like
        Converted with ``np.array`` and stored under the key ``matrix2``.
    data_dir : str
        Destination directory.
    """
    matrix2 = np.array(position)
    # os.path.join keeps the path valid on every OS (the previous '{}\{}'
    # format string hard-coded the Windows separator).
    dataPath = os.path.join(data_dir, name)
    np.savez(dataPath, matrix1=data, matrix2=matrix2)
    print('Successfully save data for {}: {}'.format(index, name))

In [10]:
# Wrap it up. Go through the preprocess stage
def preprocess(father_dir, pic_dir, data_dir, start_index, end_index, random_slide=300, window_shape_value=1000):
    """Crop a random window around the trail for a range of input files.

    For every index in ``[start_index, end_index)`` (capped at the number of
    entries in ``father_dir``) the file is loaded, NaNs are replaced with
    ``np.nan_to_num``, a window of side ``window_shape_value`` is picked at
    random around the trail centre, and the cropped result is written both
    as a picture in ``pic_dir`` and as an archive in ``data_dir``.
    """
    # Never iterate past the number of entries actually present.
    stop_index = min(len(os.listdir(father_dir)), end_index)

    for index in range(start_index, stop_index):
        # Locate the source file produced by download.py.
        path, name = path_of_data(father_dir, index)
        print('Start preprocessing {}: {}'.format(index, name))

        # Read the image and the trail position, replacing NaNs in the image.
        raw_image, trail_box = load_saved_data(path)
        image = np.nan_to_num(raw_image)

        # Choose where the crop window goes, based on the trail centre.
        window_position = random_pick_xy(
            image.shape, find_center(trail_box), random_slide, window_shape_value)

        # Crop the image and express the trail position in window coordinates.
        cropped_image, cropped_box = data_position_refresh(
            image, window_shape_value, trail_box, window_position)

        # Write the preview picture and the data archive.
        plot_data(name, cropped_image, cropped_box, pic_dir)
        save_data(name, index, cropped_image, cropped_box, data_dir)

In [11]:
# Randomly split the data set into train, validation and test sets (roughly 70/15/15).
def random_split(dir_train, dir_val, dir_test):
    """Randomly move files out of ``dir_train`` into ``dir_test`` or ``dir_val``.

    One ``random.random()`` draw is made per entry of ``dir_train``:
    below 0.15 the file moves to ``dir_test``, below 0.3 it moves to
    ``dir_val``, otherwise it stays where it is. A message is printed for
    every entry.
    """
    for file in os.listdir(dir_train):
        source = os.path.join(dir_train, file)
        val_target = os.path.join(dir_val, file)
        test_target = os.path.join(dir_test, file)

        draw = random.random()
        if draw >= 0.3:
            # Majority case: the file remains in the training directory.
            print('File does not move')
            continue

        destination = test_target if draw < 0.15 else val_target
        shutil.move(source, destination)
        print('File move from {} to {}'.format(source, destination))

In [14]:
# Main Function
if __name__ == '__main__':
    # Build every path with os.path.join so the notebook runs on any OS and
    # no longer relies on invalid escape sequences such as '\d' and '\p'.
    raw_dir = os.path.join('.', 'data')
    pic_dir = os.path.join('.', 'pic_positive')
    train_dir = os.path.join('.', 'preprocess', 'train', 'positive')
    val_dir = os.path.join('.', 'preprocess', 'val', 'positive')
    test_dir = os.path.join('.', 'preprocess', 'test', 'positive')

    # Saving pictures/archives and moving files fail if the destination
    # directories are missing, so create them up front.
    for out_dir in (pic_dir, train_dir, val_dir, test_dir):
        os.makedirs(out_dir, exist_ok=True)

    # Crop the first 100 input files into the training directory ...
    preprocess(raw_dir, pic_dir,
               train_dir,
               start_index=0, end_index=100,
               random_slide=300,
               window_shape_value=1000)
    # ... then move a random share of them to the validation and test directories.
    random_split(train_dir,
                 val_dir,
                 test_dir
                 )

Start preprocessing 0: 0-ib1901010
Successfully save data for 0: 0-ib1901010
Start preprocessing 1: 1-ib2r03020
Successfully save data for 1: 1-ib2r03020
Start preprocessing 2: 10-ib4a28020
Successfully save data for 2: 10-ib4a28020
Start preprocessing 3: 100-ibug51010
Successfully save data for 3: 100-ibug51010
Start preprocessing 4: 1000-jb3g02010
Successfully save data for 4: 1000-jb3g02010
Start preprocessing 5: 1001-jb3g05010
Successfully save data for 5: 1001-jb3g05010
Start preprocessing 6: 1004-jb4dl1010
Successfully save data for 6: 1004-jb4dl1010
Start preprocessing 7: 1006-jb5v04010
Successfully save data for 7: 1006-jb5v04010
Start preprocessing 8: 1007-jb5v04010
Successfully save data for 8: 1007-jb5v04010
Start preprocessing 9: 1008-jb6101010
Successfully save data for 9: 1008-jb6101010
Start preprocessing 10: 1009-jb6115010
Successfully save data for 10: 1009-jb6115010
Start preprocessing 11: 101-iby102010
Successfully save data for 11: 101-iby102010
Start preprocessing 

<Figure size 640x480 with 0 Axes>

In [13]:
# Import the project-local `cleanup` module and run its positive-sample cleanup.
# NOTE(review): what clean_up_positive() removes or modifies is not visible in
# this notebook -- check cleanup.py before re-running.
# NOTE(review): this cell's execution count (13) is lower than the main cell's
# (14), so it was last executed before the main cell; confirm the intended
# order for a Restart & Run All.
import cleanup
cleanup.clean_up_positive()