In [None]:
import itertools
import os
import time

import torch
import torch.optim as optim
from torchvision import transforms as T

import plot
from data_manager import rgbd_gradients_dataset, rgbd_gradients_dataloader
from models import SpecialFuseNetModel
%matplotlib notebook

In [None]:
# Dataset location, resolved relative to the notebook's current working directory.
CWD = os.getcwd()
# Pass path components separately so os.path.join uses the platform's separator.
DATASET_DIR = os.path.join(CWD, 'data', 'nyuv2')
print(DATASET_DIR)

In [None]:
# Runtime and data-pipeline configuration.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
IMAGE_SIZE = (64, 64)     # (H, W) that every RGB and depth image is resized to
TRAIN_TEST_RATIO = 0.9    # forwarded to rgbd_gradients_dataloader; presumably the train fraction — confirm
BATCH_SIZE = 4
NUM_WORKERS = 4           # worker count forwarded to rgbd_gradients_dataloader
# Seed torch's global RNG (CPU and CUDA) so torch-driven shuffling, splits and
# weight init are repeatable across Restart & Run All.
# NOTE(review): RNGs outside torch (e.g. `random`/numpy) used inside data_manager,
# if any, are not covered by this seed.
SEED = 42
torch.manual_seed(SEED)

In [None]:
def _resize_to_unit_range(num_channels):
    """Build a PIL.Image -> tensor pipeline for images with ``num_channels`` channels.

    Steps: resize to IMAGE_SIZE, convert to a tensor, then apply
    Normalize with mean=std=0.5 per channel, i.e. (x - 0.5) / 0.5, which maps
    a [0, 1] dynamic range onto [-1, 1].
    """
    half_per_channel = (.5,) * num_channels
    return T.Compose([
        T.Resize(IMAGE_SIZE),   # constant spatial dimensions
        T.ToTensor(),           # PIL.Image -> torch.Tensor
        T.Normalize(mean=half_per_channel, std=half_per_channel),
    ])


# Three-channel pipeline for the RGB images.
tf_rgb = _resize_to_unit_range(3)
# Single-channel pipeline for the depth maps.
# NOTE(review): ToTensor only rescales to [0, 1] for 8-bit images; if the depth
# files are 16-bit, the values will not land in [-1, 1] — confirm the depth image mode.
tf_depth = _resize_to_unit_range(1)

In [None]:
# Standalone dataset instance for inspection; the train/test loaders are
# constructed independently from the same root directory.
rgbd_grads_ds = rgbd_gradients_dataset(
    root=DATASET_DIR,
    transforms_rgb=tf_rgb,
    transforms_depth=tf_depth,
)

In [None]:
# Build the train and test loaders (returned as a pair) from DATASET_DIR.
# NOTE(review): the keyword is spelled `train_test_ration` to match the
# data_manager signature as called here — do not "fix" it here without
# renaming the parameter there as well.
dl_train, dl_test = rgbd_gradients_dataloader(
    root=DATASET_DIR,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    train_test_ration=TRAIN_TEST_RATIO,
    transforms_rgb=tf_rgb,
    transforms_depth=tf_depth,
)

In [None]:
# Flip to True to plot the first few dataset samples. This is an explicit toggle
# rather than commented-out code, so the call stays syntax-checked and discoverable.
SHOW_SAMPLE_PLOTS = False
if SHOW_SAMPLE_PLOTS:
    plot.rgbd_gradients_dataset_first_n(dataset=rgbd_grads_ds, n=5)
print(f'Found {len(rgbd_grads_ds)} images in dataset folder.')

In [None]:
# Measure the wall-clock interval between consecutive batches yielded by dl_train.
# The first interval also includes iterator/worker start-up.
# Cap the number of timed batches for a quick estimate (e.g. 20); None times a full pass.
MAX_TIMED_BATCHES = None
print('Measure batch generation time:')
batch_times = []
# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
for _ in itertools.islice(dl_train, MAX_TIMED_BATCHES):
    now = time.perf_counter()
    batch_times.append(now - start)
    start = now
if batch_times:
    mean_batch_time = sum(batch_times) / len(batch_times)
    print(f'{mean_batch_time:.4f} s/batch (mean over {len(batch_times)} batches)')
else:
    # Avoid ZeroDivisionError when the loader is empty (e.g. missing data).
    print('dl_train yielded no batches; nothing to time.')

In [None]:
# Pull a single training batch and record each tensor's per-sample shape
# (the leading batch dimension is dropped). The gradients live under key 'x'.
sample_batch = next(iter(dl_train))
rgb_size, depth_size, grads_size = (
    tuple(sample_batch[key].shape[1:]) for key in ('rgb', 'depth', 'x')
)

In [None]:
# Instantiate the model from the per-sample shapes inferred above, on the
# selected device, with mode='train'.
fusenetmodel = SpecialFuseNetModel(
    rgb_size=rgb_size,
    depth_size=depth_size,
    grads_size=grads_size,
    device=device,
    mode='train',
)