In [None]:
import sys
import os
# Put the project's helper directory first on sys.path so the local modules
# imported on the next line (read, preprocess, explore) can be resolved.
sys.path.insert(0, os.path.join(os.path.expanduser("~"),"Desktop","projects", "GlacierView",
                                "src","segmentation","helpers"))
import read, preprocess, explore
from tqdm import tqdm

import rasterio
import pandas as pd

import pickle

import numpy as np
import tifffile
import geopandas as gpd
import tensorflow as tf
from datetime import date
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from PIL import Image
from scipy.stats import logistic
import cv2

import importlib
# Reload the local helper modules so that edits made to them on disk are picked
# up without restarting the kernel.
importlib.reload(read)
importlib.reload(preprocess)
importlib.reload(explore)

# TRAINING

In [None]:
# --- Inputs -------------------------------------------------------------------
# Labels selecting which exported datasets are read.
data_label = "localized_time_series_for_training_c02_t1_l2"
dem_data_label = "localized_time_series_for_segmentation_training_large"
dem_label = "NASADEM"

# All paths are rooted at the GlacierView checkout under ~/Desktop/projects.
glacier_view_dir = os.path.join(os.path.expanduser('~'), "Desktop", "projects", "GlacierView")

# Shared prefixes, factored out so each path below only spells its unique tail.
_earth_engine_data = os.path.join(glacier_view_dir, "src", "earth_engine", "data")
_landing_zone = os.path.join(_earth_engine_data, "ee_landing_zone")
_training_data = os.path.join(glacier_view_dir, "src", "segmentation", "training", "data")

glaciers_dir = os.path.join(_landing_zone, data_label, "landsat")
dem_dir = os.path.join(_landing_zone, dem_data_label, "dems")
masks_dir = os.path.join(_training_data, "masks_staging_2")
log_dir = os.path.join(_landing_zone, data_label, "logs")
output_dir = os.path.join(_earth_engine_data, "processed_metadata", data_label)
log_path = os.path.join(log_dir, "training_log_1.log")

# Entries of the landsat directory, skipping hidden files, in sorted order.
glims_ids = sorted(
    entry for entry in os.listdir(glaciers_dir) if not entry.startswith('.')
)

# --- Outputs ------------------------------------------------------------------
processed_training_data = os.path.join(_training_data, "processed_training_data_2")
images_write_path = os.path.join(processed_training_data, "images")
masks_write_path = os.path.join(processed_training_data, "masks")

In [None]:
# Bands kept from every raster (passed to preprocess.get_common_bands).
common_bands = [
    'blue',
    'green',
    'red',
    'nir',
    'swir',
    'thermal',
]

# Target size handed to preprocess.resize_rasters for images, DEMs and masks.
dim = (128, 128)

In [None]:
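# Optional resume helper (disabled): uncomment to list the glaciers whose image
# has not been written to images_write_path yet.
#completed_glims_ids = os.listdir(images_write_path)
# remaining_glims_ids = [i for i in glims_ids if i + '.tif' not in completed_glims_ids]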

In [None]:
# Load the table that drives the loops below; each row supplies a `glims_id`
# and a `file_name`.
metadata_dir = os.path.join(
    glacier_view_dir, "src", "earth_engine", "data", "processed_metadata", data_label
)
filtered_csv_path = os.path.join(metadata_dir, "filtered_training_data.csv")
df = pd.read_csv(filtered_csv_path)

In [None]:
# Build one training sample per metadata row and write it to disk:
#   <images_write_path>/<glims_id>.tif  - stacked arrays (image bands + DEM)
#   <masks_write_path>/<glims_id>.tif   - the resized mask
# Rows whose mask file is missing from masks_dir are skipped.
band_num = 5  # NOTE(review): not referenced in the visible cells - confirm before removing

# The writes below go straight into these directories, so create them up front.
os.makedirs(images_write_path, exist_ok=True)
os.makedirs(masks_write_path, exist_ok=True)

# iterrows() is a generator, so tqdm needs `total` to draw a real progress bar.
for idx, row in tqdm(df.iterrows(), total=len(df)):
    mask_file_name = f"{row.glims_id}.tif"
    # Use dem_label rather than a hard-coded "NASADEM" so the lookup key always
    # matches the file that is actually read below.
    dem_file_name = f"{row.glims_id}_{dem_label}.tif"

    # Look for the mask first: rows without one are skipped before any raster
    # is read or preprocessed. The context manager closes the file handle once
    # the pixels have been copied into the array.
    try:
        with Image.open(os.path.join(masks_dir, mask_file_name)) as img:
            mask_array = np.expand_dims(np.array(img), 2)
    except FileNotFoundError:
        continue

    # Per-row containers keyed by glims_id (rebuilt on every iteration).
    images = {}
    dems = {}
    masks = {}

    images[row.glims_id] = read.get_rasters(os.path.join(glaciers_dir, row.glims_id), row.file_name)
    dems[row.glims_id] = read.get_dem(os.path.join(dem_dir, dem_file_name))

    images[row.glims_id] = preprocess.get_common_bands(images[row.glims_id], common_bands)
    images[row.glims_id] = preprocess.normalize_rasters(images[row.glims_id])
    images[row.glims_id] = preprocess.resize_rasters(images[row.glims_id], dim)

    dems[row.glims_id] = preprocess.normalize_rasters(dems[row.glims_id])
    dems[row.glims_id] = preprocess.resize_rasters(dems[row.glims_id], dim)

    masks[row.glims_id] = {mask_file_name: mask_array}
    masks[row.glims_id] = preprocess.resize_rasters(masks[row.glims_id], dim)

    image = images[row.glims_id]
    dem = dems[row.glims_id]
    mask = masks[row.glims_id]

    # One array per raster in `image`, each with the DEM appended along axis 2.
    # NOTE(review): assumes read.get_dem keys its result by file name - confirm.
    X = np.stack(
        [np.concatenate((image[file_name], dem[dem_file_name]), axis=2) for file_name in image]
    )

    # NOTE(review): outputs are named by glims_id only, so if df holds several
    # rows for the same glims_id the later row overwrites the earlier file.
    # imwrite is the current name of the deprecated imsave alias.
    tifffile.imwrite(os.path.join(images_write_path, f"{row.glims_id}.tif"), X, planarconfig='contig')
    tifffile.imwrite(os.path.join(masks_write_path, f"{row.glims_id}.tif"), mask[mask_file_name])
    

In [None]:
# Visual spot-check: rebuild the stacked sample for each metadata row and show
# its first array as a three-channel image instead of writing anything to disk.
# Rows whose mask file is missing from masks_dir are skipped.
band_num = 5  # NOTE(review): not referenced in the visible cells - confirm before removing

# iterrows() is a generator, so tqdm needs `total` to draw a real progress bar.
for idx, row in tqdm(df.iterrows(), total=len(df)):
    mask_file_name = f"{row.glims_id}.tif"
    # Use dem_label rather than a hard-coded "NASADEM" so the lookup key always
    # matches the file that is actually read below.
    dem_file_name = f"{row.glims_id}_{dem_label}.tif"

    # Look for the mask first: rows without one are skipped before any raster
    # is read or preprocessed. The context manager closes the file handle once
    # the pixels have been copied into the array.
    try:
        with Image.open(os.path.join(masks_dir, mask_file_name)) as img:
            mask_array = np.expand_dims(np.array(img), 2)
    except FileNotFoundError:
        continue

    # Per-row containers keyed by glims_id (rebuilt on every iteration).
    images = {}
    dems = {}
    masks = {}

    images[row.glims_id] = read.get_rasters(os.path.join(glaciers_dir, row.glims_id), row.file_name)
    dems[row.glims_id] = read.get_dem(os.path.join(dem_dir, dem_file_name))

    images[row.glims_id] = preprocess.get_common_bands(images[row.glims_id], common_bands)
    images[row.glims_id] = preprocess.normalize_rasters(images[row.glims_id])
    images[row.glims_id] = preprocess.resize_rasters(images[row.glims_id], dim)

    dems[row.glims_id] = preprocess.normalize_rasters(dems[row.glims_id])
    dems[row.glims_id] = preprocess.resize_rasters(dems[row.glims_id], dim)

    masks[row.glims_id] = {mask_file_name: mask_array}
    masks[row.glims_id] = preprocess.resize_rasters(masks[row.glims_id], dim)

    image = images[row.glims_id]
    dem = dems[row.glims_id]
    mask = masks[row.glims_id]

    # One array per raster in `image`, each with the DEM appended along axis 2.
    # NOTE(review): assumes read.get_dem keys its result by file name - confirm.
    X = np.stack(
        [np.concatenate((image[file_name], dem[dem_file_name]), axis=2) for file_name in image]
    )

    print(row.file_name)
    # Select channels 2, 1, 0 of the first stacked array as an (H, W, 3) image.
    # Same pixels as the earlier np.rollaxis(X[0,:,:,[2,1,0]],0,3) form, without
    # relying on how mixed advanced indexing reorders axes.
    # NOTE(review): presumably red/green/blue given the order of common_bands - confirm.
    plt.imshow(X[0][:, :, [2, 1, 0]])
    plt.show()

In [None]:
# Shape of the most recently stacked sample array `X` left behind by the loop above.
X.shape

In [None]:
# Pick one raster out of `images` for manual inspection.
# NOTE: this rebinds `glims_ids` (set from the directory listing in the inputs
# cell) to the keys currently held in `images`; re-run the inputs cell to get
# the full listing back.
glims_ids = list(images.keys())
glims_index = 1 #modify this index
# `images` is rebuilt for every row by the loop above, so it can hold fewer
# entries than the requested index; clamp to the last available entry instead of
# failing with an IndexError.
glims_id = glims_ids[min(glims_index, len(glims_ids) - 1)]

file_names = list(images[glims_id].keys())
file_name = file_names[0]
img = images[glims_id][file_name]

print(f"{glims_id}/{file_name}")

#/Users/mattw/Desktop/projects/GlacierView/src/earth_engine/data/ee_landing_zone/localized_time_series_for_training_c02_t1_l2/landsat/G268658E81610N/G268658E81610N_2015-08-10_L8_C02_T1_L2_SR.tif
#G268658E81610N/G268658E81610N_2015-08-10_L8_C02_T1_L2_SR.tif

In [None]:
#explore.view_training_images(X_train, where = 0, n=100)

In [None]:
#X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=1)
#training_data = list(zip(X_train,y_train))
#test_data = list(zip(X_test, y_test))