In [None]:
import sys
sys.path.append('../')
import tensorflow as tf
import numpy as np
from os import environ as env
from os import path
from dataset_utils.md_utils import md_dataset
from dataset_utils.mai_utils import mai_dataset
from dataset_utils.phoneDepth_utils import phoneDepth_dataset, decompose_train_sample_in_batches, confidence_indeces
from dataset_utils.viz_utils import visualize_multiple_imag_rows
from dataset_utils.aug_utils import color_jitter, salty_noise, random_rotation, random_crop_and_resize, cascade_functions

In [None]:
# Root folder that holds every dataset, read from the DATA_DIR environment
# variable (a missing variable raises KeyError). `env` is os.environ as
# imported in the first cell; the bare `os` module is never imported there,
# so `os.environ` would raise NameError on a fresh kernel.
gen_data_dir = env['DATA_DIR']

# Short dataset key -> folder name under gen_data_dir.
dataset_type = 'mb'
dataset_locations = {"md": "MegaDepth_v1",
                     "mai": "MAI2021_dataset",
                     "mb": "FTDataset"}

# Resolve the selected dataset to its folder on disk.
dataset_name = dataset_locations[dataset_type]
data_dir = path.join(gen_data_dir, dataset_name)


In [None]:
# Phone identifier forwarded to phoneDepth_dataset.
phone = 'hua'

# Sample layout forwarded to phoneDepth_dataset; keep exactly one line active.
# io_mode = 'img2depth'
# io_mode = "img2projected"
# io_mode = "img_depth2depth"
# io_mode = "img2depth_depth"
# io_mode = "img_depth2depth_depth"
# io_mode = "img2depth_conf"
io_mode = "img2depth_depth_conf"

# Per-io_mode lookup, handed to the crop/resize transform below.
conf_indx = confidence_indeces[io_mode]

# Define Augmentation Functions
# NOTE(review): the first positional argument (0.9) is presumably the
# probability of applying the augmentation -- confirm in aug_utils.
jitter = color_jitter(0.9, brightness=0.1, contrast=0.1, saturation=0.1, hue = 0.1)
salt_noise = salty_noise(0.9, 0.01)

# Image-only augmentations chained into a single callable.
combined_img_augmentation = cascade_functions([jitter, salt_noise])
# combined_img_augmentation = None

crop_resize_transform = random_crop_and_resize(prob=0.9, min_size=0.6, max_size=1.0, img_shape= (224,224), center_crop=False, conf_indx=conf_indx)
rotation_aug_transform = random_rotation(0.9, 2.5)

# Cascaded geometric transformation
geometric_augmentation = cascade_functions([crop_resize_transform, rotation_aug_transform])
# geometric_augmentation = crop_resize_transform
# geometric_augmentation = None

# Validation uses no geometric transform. The alternative below was assigned
# and then immediately overwritten with None, so it is kept only as a toggle.
# Note negative probability, for stability. Don't want to crop in validation
# val_geometric_transform = random_crop_and_resize(prob=-1e-5, img_shape=(128,128))
# NOTE(review): val_geometric_transform is not passed to dataset_val in this cell.
val_geometric_transform = None

shuffle = True

# Single source of truth for the batch size of both pipelines below.
batch_size = 8


dataset_split = 'train'

random_seed = 123

# Seed TensorFlow and NumPy right before building each dataset so that both
# pipelines are constructed from the same random state.
tf.random.set_seed(random_seed)
np.random.seed(random_seed)
dataset_train = phoneDepth_dataset(data_dir, mode=dataset_split, input_size=(480, 640), batch_size=batch_size, random_flip=True, shuffle=shuffle,
                            phone=phone, io_mode=io_mode,
                            geometric_aug_transform=geometric_augmentation, img_aug_transform=combined_img_augmentation)

# NOTE(review): dataset_val is also built from `dataset_split` ('train'), just
# without augmentation or flipping -- presumably so the same samples can be
# compared side by side with the augmented pipeline; confirm this is intended.
tf.random.set_seed(random_seed)
np.random.seed(random_seed)
dataset_val = phoneDepth_dataset(data_dir, mode=dataset_split, input_size=(320, 320), out_size=(480,640), batch_size=batch_size, random_flip=False, shuffle=shuffle,
                            phone=phone, io_mode=io_mode)

In [None]:
seed = 1234

# Draw one batch from the augmented pipeline with both RNGs freshly seeded.
tf.random.set_seed(seed)
np.random.seed(seed)
data_train_it = dataset_train.as_numpy_iterator()
train_sample = next(data_train_it)

# Reseed BOTH generators again before the second pipeline, mirroring the
# seeding done for the train iterator above (previously only TensorFlow was
# reseeded here, leaving NumPy's state dependent on the train draw).
tf.random.set_seed(seed)
np.random.seed(seed)
data_val_it = dataset_val.as_numpy_iterator()
val_sample = next(data_val_it)


# Split each sample into its component batches; element 0 is cast to uint8
# purely so it can be displayed.
train_batches = decompose_train_sample_in_batches(train_sample)
train_batches[0] = tf.cast(train_batches[0], tf.uint8)          # For Display

val_batches = decompose_train_sample_in_batches(val_sample)
val_batches[0] = tf.cast(val_batches[0], tf.uint8)              # For Display

# One labelled row per component: all train components first, then all val ones.
img_batches = train_batches + val_batches
labels = ['train_{}'.format(i) for i in range(len(train_batches))] + ['val_{}'.format(i) for i in range(len(val_batches))]
visualize_multiple_imag_rows(img_batches, labels, n_samples = 8, histogram=True, color_map='viridis')


In [None]:
# A notebook cell only displays its last expression, so report both values as
# one tuple: (number of components in the sample, shape of component 1).
len(train_batches), train_batches[1].shape