# 通过ImageSampler对卫星图进行裁剪生成数据集

In [28]:
from PIL import Image, ImageFile
import cv2
import numpy as np
import pickle
import sys
# Make the parent directory importable so the project-local image_sampler module resolves.
sys.path.append("../")
from image_sampler import ImageSampler
# The input images are opened with PIL below; lift PIL's pixel-count limit so
# very large images do not raise DecompressionBombError.
Image.MAX_IMAGE_PIXELS = None # Disable DecompressionBombError
# Let PIL load image files whose data is cut short instead of raising OSError.
ImageFile.LOAD_TRUNCATED_IMAGES = True # Disable OSError: image file is truncated

In [29]:
# Input locations, all relative to the directory this notebook runs in.
# Images that are opened with PIL and cropped by ImageSampler:
train_path = "../Datasets/dataset_original/train.png"
test_path = "../Datasets/dataset_original/test.png"
# Mask image passed to ImageSampler alongside both the train and the test image.
mask_path = "../Datasets/ns_road_nofootpath_T.png"
# Pickled GPS records (taxi and bus), loaded with pickle.load.
taxi_gps_path = "../Datasets/GPS_data/taxi/GPS_taxi_2.0.pkl"
bus_gps_path = "../Datasets/GPS_data/bus/GPS_bus_2.0.pkl"

# Alternative inputs under test_6 / test.pkl; uncomment to use them in place of
# the paths above (presumably a small smoke-test set -- confirm before relying on it).
# train_path = "../Datasets/test_6/22.520000-113.920000.png"
# test_path = "../Datasets/test_6/22.520000-113.920000.png"
# mask_path = "../Datasets/test_6/22.520000-113.920000.png"
# taxi_gps_path = "../Datasets/GPS_data/test.pkl"
# bus_gps_path = "../Datasets/GPS_data/test.pkl"

In [30]:
def _load_bgra(image_path):
    """Open an image with PIL and return it as an array in OpenCV channel order.

    The PIL image is converted to a NumPy array and its red and blue channels
    are swapped with cv2.COLOR_RGBA2BGRA.
    """
    pil_image = Image.open(image_path)
    rgba_pixels = np.array(pil_image)
    return cv2.cvtColor(rgba_pixels, cv2.COLOR_RGBA2BGRA)


def _load_pickle(pickle_path):
    """Unpickle and return the object stored at `pickle_path`.

    encoding='bytes' is forwarded to pickle.load unchanged. Pickle can execute
    arbitrary code while loading, so only use this on trusted files.
    """
    with open(pickle_path, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')


# Images first (train, test, mask), then the GPS tables (taxi, bus).
train_sat_img = _load_bgra(train_path)
test_sat_img = _load_bgra(test_path)
mask_img = _load_bgra(mask_path)
taxi_gps = _load_pickle(taxi_gps_path)
bus_gps = _load_pickle(bus_gps_path)

In [31]:
def _read_gps_data(big_gps, coordinate_ranges, image_size=1024):
    """Select the GPS points inside each coordinate range and rescale them to pixels.

    Args:
        big_gps: pandas DataFrame with numeric 'lon' and 'lat' columns. It is
            only read; every returned frame is an independent copy.
        coordinate_ranges: iterable of ranges, each indexable as
            ((lon_min, lat_max), (lon_max, lat_min)), i.e. range[0] holds the
            smallest longitude and largest latitude, range[1] the opposite.
        image_size: number of pixels the range spans along each axis.
            Defaults to 1024, the value that used to be hard-coded.

    Returns:
        A list with one DataFrame per range, in input order. Each keeps only
        the rows whose 'lon'/'lat' fall inside the range (bounds inclusive),
        with 'lon' replaced by the integer pixel offset measured from lon_min
        and 'lat' by the integer pixel offset measured downwards from lat_max.
    """
    gps_data = []
    for coordinate_range in coordinate_ranges:
        lon_min, lat_max = coordinate_range[0][0], coordinate_range[0][1]
        lon_max, lat_min = coordinate_range[1][0], coordinate_range[1][1]

        # Keep only the points that lie inside this range (bounds inclusive).
        inside = (big_gps['lon'].between(lon_min, lon_max)
                  & big_gps['lat'].between(lat_min, lat_max))
        # Explicit copy: the column assignments below must never be treated as
        # writes into a slice of big_gps (the chained-assignment situation
        # pandas warns about).
        selected_rows = big_gps[inside].copy()

        # Rescale to pixel offsets; adding 0.5 before the integer cast rounds
        # to the nearest pixel.
        # NOTE(review): points exactly on the far edge map to image_size, one
        # past the last index of an image_size-wide image -- confirm that
        # consumers tolerate this.
        selected_rows['lon'] = (
            (selected_rows['lon'] - lon_min) / (lon_max - lon_min) * image_size + 0.5
        ).astype(int)
        selected_rows['lat'] = (
            (lat_max - selected_rows['lat']) / (lat_max - lat_min) * image_size + 0.5
        ).astype(int)
        gps_data.append(selected_rows)
    return gps_data

In [32]:
import pandas as pd
# Turn off pandas' chained-assignment (SettingWithCopy) warning for the whole session.
# Presumably needed for the column assignments inside _read_gps_data -- confirm.
pd.options.mode.chained_assignment = None  # Suppress the warning

# Sample the training image together with the mask image. images_sample() returns
# three values; the third is later passed to _read_gps_data as its coordinate ranges.
# NOTE(review): the meaning of the "train"/"test" argument is defined by ImageSampler,
# which is not visible here.
train_sam = ImageSampler(train_sat_img, mask_img, "train")
train_sat_imgs, train_mask_imgs, train_coors = train_sam.images_sample()

# Same for the test image; note that it is paired with the same mask image.
test_sam = ImageSampler(test_sat_img, mask_img, "test")
test_sat_imgs, test_mask_imgs, test_coors = test_sam.images_sample()

In [33]:
# Per-sample GPS tables: one _read_gps_data call per (GPS source, split) pair,
# evaluated in the order taxi/train, taxi/test, bus/train, bus/test.
_splits = (train_coors, test_coors)
train_gps_data, test_gps_data = [
    _read_gps_data(taxi_gps, split_coors) for split_coors in _splits
]
train_gps_data_bus, test_gps_data_bus = [
    _read_gps_data(bus_gps, split_coors) for split_coors in _splits
]

In [34]:
def _write_image(path, image):
    """Write `image` to `path` with OpenCV, raising OSError if nothing was written."""
    # cv2.imwrite signals failure (for example a missing output directory) by
    # returning False rather than raising, so the result has to be checked.
    if not cv2.imwrite(path, image):
        raise OSError(f"cv2.imwrite could not write {path}")


def _save_split(sat_imgs, mask_imgs, coors, taxi_data, bus_data,
                image_dir, mask_dir, row_offset=0, samples_per_row=10):
    """Write every sample of one split below ../Datasets/dataset_template_copy.

    Sample i is named "<i // samples_per_row + row_offset>_<i % samples_per_row>".

    Args:
        sat_imgs: sequence of satellite image arrays; its length sets the
            number of samples written.
        mask_imgs: sequence of mask image arrays, indexed like sat_imgs.
        coors: per-sample coordinate values, written as text via str().
        taxi_data: per-sample taxi GPS objects, pickled.
        bus_data: per-sample bus GPS objects, pickled.
        image_dir: sub-directory (relative to the dataset root) for *_sat.png.
        mask_dir: sub-directory (relative to the dataset root) for *_mask.png.
        row_offset: value added to the row part of every file name.
        samples_per_row: how many consecutive samples share one row number.

    Raises:
        OSError: if an image, text or pickle file cannot be written.
    """
    root = "../Datasets/dataset_template_copy"
    for i, sat_img in enumerate(sat_imgs):
        row, col = divmod(i, samples_per_row)
        stem = f"{row + row_offset}_{col}"
        # Satellite image and mask image
        _write_image(f"{root}/{image_dir}/{stem}_sat.png", sat_img)
        _write_image(f"{root}/{mask_dir}/{stem}_mask.png", mask_imgs[i])
        # Coordinate range of the sample
        with open(f"{root}/coordinates/{stem}_gps.txt", 'w') as f:
            f.write(f"{coors[i]}")
        # Taxi GPS data
        with open(f"{root}/GPS/taxi/{stem}_gps.pkl", 'wb') as f:
            pickle.dump(taxi_data[i], f)
        # Bus GPS data
        with open(f"{root}/GPS/bus/{stem}_gps.pkl", 'wb') as f:
            pickle.dump(bus_data[i], f)


# Save the dataset: training samples first.
_save_split(train_sat_imgs, train_mask_imgs, train_coors,
            train_gps_data, train_gps_data_bus,
            image_dir="train_val/image", mask_dir="train_val/mask")

# Test samples share the coordinates/ and GPS/ directories with the training
# samples, so their row numbers are shifted by 2000 to keep file names distinct.
# NOTE(review): names would collide if the training split ever reached row 2000.
_save_split(test_sat_imgs, test_mask_imgs, test_coors,
            test_gps_data, test_gps_data_bus,
            image_dir="test/image_test", mask_dir="test/mask",
            row_offset=2000)