In [38]:
import glob
import os
import sys

import matplotlib
import matplotlib.pyplot as plt

# Global matplotlib defaults: larger figure titles and tick labels for every later plot.
params = dict(
    [
        ("axes.titlesize", "26"),
        ("xtick.labelsize", "20"),
        ("ytick.labelsize", "20"),
    ]
)
matplotlib.rcParams.update(params)

import numpy as np
import pandas as pd
import rasterio as rio

# Make the directory above the notebook's working directory importable
# (presumably so the `utils` package imported next can be found).
# Guarded so that re-running this cell does not keep appending duplicates.
PROJECT_ROOT = os.path.join(os.path.abspath(""), "../")
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
from utils.raster_utils import get_coord_from_raster
from utils.utils import get_sample_name


# --- Data locations ----------------------------------------------------------
MASKS_ROOT = "/data"
METADATA_PATH = "/data/masks_metadata.csv"
SENSOR = "lc"
ROI_LIST = [
    "1158_spring",
    "1868_summer",
    "1970_fall",
    "2017_winter",
]

# --- Split settings ----------------------------------------------------------
# Fractions of (ROI, area) groups held out for validation / test; the rest is train.
val_fraction, test_fraction = 0.05, 0.05
lists_save_directory = "../config/dataset/lists"

In [39]:
# Load the metadata table; the CSV's first column becomes the row index.
metadata = pd.read_csv(
    METADATA_PATH,
    index_col=0,
)

In [40]:
# Peek at the first five rows of the metadata table.
metadata.head(n=5)

Unnamed: 0_level_0,filepath,filename,ROI,area,x0,y0,x1,y1,IGBP_9,IGBP_11,...,LCCS_SH_15,LCCS_LU_0,LCCS_LC_24,LCCS_LU_23,LCCS_SH_22,LCCS_LC_39,LCCS_LU_37,LCCS_SH_37,LCCS_LU_16,LCCS_SH_34
No,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,/data/ROIs1158_spring/lc_128/ROIs1158_spring_l...,ROIs1158_spring_lc_128_p81,1158_spring,128,600796.364733,6094633.0,603356.364733,6092073.0,15676,1723,...,0,0,0,0,0,0,0,0,0,0
1,/data/ROIs1158_spring/lc_128/ROIs1158_spring_l...,ROIs1158_spring_lc_128_p48,1158_spring,128,595676.364733,6095913.0,598236.364733,6093353.0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,/data/ROIs1158_spring/lc_128/ROIs1158_spring_l...,ROIs1158_spring_lc_128_p660,1158_spring,128,599516.364733,6069033.0,602076.364733,6066473.0,11127,0,...,0,0,0,0,0,0,0,0,0,0
3,/data/ROIs1158_spring/lc_128/ROIs1158_spring_l...,ROIs1158_spring_lc_128_p755,1158_spring,128,572636.364733,6063913.0,575196.364733,6061353.0,42783,0,...,0,0,0,0,0,0,0,0,0,0
4,/data/ROIs1158_spring/lc_128/ROIs1158_spring_l...,ROIs1158_spring_lc_128_p834,1158_spring,128,599516.364733,6061353.0,602076.364733,6058793.0,2148,0,...,0,0,0,0,0,0,0,0,0,0


In [41]:
# Derive a per-row sample name from each mask filename; these names are what
# the split list files are later built from. Pass the function directly to
# `.map` instead of wrapping it in a redundant lambda.
metadata["sample"] = metadata["filename"].map(get_sample_name)

In [42]:
# Split at the (ROI, area) level rather than per row, so every row sharing an
# (ROI, area) key ends up in the same split.
assert val_fraction >= 0 and test_fraction >= 0, "split fractions must be non-negative"
assert val_fraction + test_fraction < 1, "val_fraction + test_fraction must be < 1 to leave a train split"

# Unique (ROI, area) keys in sorted groupby order. `.size()` yields the same
# group index as `.count()` without counting non-null values in every column.
areas = metadata.groupby(["ROI", "area"]).size().index.tolist()

# Shuffle with the legacy global seed: the resulting permutation (and hence the
# saved lists) depends on this exact seeding and on the sorted order above.
np.random.seed(42)
np.random.shuffle(areas)

n_areas = len(areas)
val_end = int(val_fraction * n_areas)
test_end = int((val_fraction + test_fraction) * n_areas)

val_areas = areas[:val_end]
test_areas = areas[val_end:test_end]
train_areas = areas[test_end:]


In [43]:
# Number of (ROI, area) groups per split: train, val, test (one per line).
print(*map(len, (train_areas, val_areas, test_areas)), sep="\n")

227
12
13


In [44]:
# Assign each row to the split that owns its (ROI, area) key.
# Build the key index once instead of once per split, and select with .loc
# rather than chained indexing.
area_keys = metadata.set_index(["ROI", "area"]).index
metadata_train = metadata.loc[area_keys.isin(train_areas), "sample"]
metadata_val = metadata.loc[area_keys.isin(val_areas), "sample"]
metadata_test = metadata.loc[area_keys.isin(test_areas), "sample"]

# Every row must land in exactly one split. Rows with a missing ROI or area are
# excluded from the groupby keys (pandas drops NaN keys by default) and would
# otherwise disappear from all lists without notice.
assert len(metadata_train) + len(metadata_val) + len(metadata_test) == len(metadata), (
    "some metadata rows were not assigned to any split"
)

In [48]:
# Inspect the training samples: the `sample` values of rows whose (ROI, area) key is in train_areas.
metadata_train

No
0          ROIs1158_spring_lc_p81
1          ROIs1158_spring_lc_p48
2         ROIs1158_spring_lc_p660
3         ROIs1158_spring_lc_p755
4         ROIs1158_spring_lc_p834
                   ...           
180657    ROIs2017_winter_lc_p643
180658    ROIs2017_winter_lc_p813
180659    ROIs2017_winter_lc_p702
180660    ROIs2017_winter_lc_p737
180661    ROIs2017_winter_lc_p386
Name: sample, Length: 162354, dtype: object

In [45]:
# Number of samples per split: train, val, test (one shape tuple per line).
print(*(split.shape for split in (metadata_train, metadata_val, metadata_test)), sep="\n")

(162354,)
(8737,)
(9571,)


In [47]:
# Write one sample name per line for each split.
# np.savetxt cannot create missing parent directories, so ensure it exists.
os.makedirs(lists_save_directory, exist_ok=True)
for list_filename, split_samples in (
    ("train.txt", metadata_train),
    ("val.txt", metadata_val),
    ("test.txt", metadata_test),
):
    np.savetxt(os.path.join(lists_save_directory, list_filename), split_samples.values, fmt="%s")