In [1]:
import os
import time
import torch
import sys
import torch.optim as optim
from torchvision import transforms as T
%matplotlib notebook

In [2]:
from models import SpecialFuseNetModel
from data_manager import rgbd_gradients_dataset, rgbd_gradients_dataloader
from train import FuseNetTrainer

In [3]:
# Locate the NYUv2 data folder relative to the directory the notebook is run from.
# The path components are passed separately so os.path.join inserts the
# platform's own separator instead of relying on a hard-coded '/'.
CWD = os.getcwd()
DATASET_DIR = os.path.join(CWD, 'data', 'nyuv2')
print(DATASET_DIR)

/home/manor/cs236781-DeepLearning/project/master/data/nyuv2


In [4]:
# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

Using device: cpu


In [5]:
# Build the dataset object from DATASET_DIR with transforms switched on.
# NOTE(review): which transforms are applied is decided inside data_manager — confirm there.
rgbd_grads_ds = rgbd_gradients_dataset(
    root=DATASET_DIR,
    use_transforms=True,
)

In [6]:
# The loader factory returns a pair, unpacked here as the training and test loaders.
# NOTE(review): split ratio and batch size are chosen inside data_manager — confirm there.
(dl_train, dl_test) = rgbd_gradients_dataloader(
    root=DATASET_DIR,
    use_transforms=True,
)

In [7]:
# Disabled visual preview of the first 5 samples. NOTE(review): `plot` is not
# imported in the import cells shown in this notebook, so re-enabling the line
# below would also require importing it.
# _ = plot.rgbd_gradients_dataset_first_n(dataset=rgbd_grads_ds,n=5)
# Report the number of items the dataset object exposes via len().
print(f'Found {len(rgbd_grads_ds)} images in dataset folder.')

Found 1288 images in dataset folder.


In [8]:
# Draw one batch from the training loader and, for each of the 'rgb', 'depth'
# and 'x' entries, keep its shape with the first dimension dropped.
# NOTE(review): presumably the dropped dimension is the batch axis — confirm against the loader.
sample_batch = next(iter(dl_train))
rgb_size, depth_size, grads_size = (
    tuple(sample_batch[key].shape[1:]) for key in ('rgb', 'depth', 'x')
)

In [9]:
# Instantiate the model with the three input sizes and the selected device.
fusenetmodel = SpecialFuseNetModel(
    rgb_size=rgb_size,
    depth_size=depth_size,
    grads_size=grads_size,
    device=device,
)

[I] - device=cpu
    - seed=42
    - dropout_p=0.4
    - optimizer=None
    - scheduler=None
[I] - Init SpecialFuseNet
    - warm start=True
    - BN momentum=0.1
    - dropout_p=0.4
[I] - Initialize Net.
    - Init type=xavier
    - Init gain=0.02

[I] - default optimizer set: SGD(lr=0.001,momentum=0.9,weight_decay=0.0005)
[I] - default scheduler set: StepSR(step_size=1000,gamma=0.1)


In [10]:
# Hand the model and the selected device to the project's trainer.
trainer = FuseNetTrainer(
    model=fusenetmodel,
    device=device,
)

In [11]:
# Base name (without extension) passed to the trainer for checkpointing.
checkpoint_file = 'checkpoints/special_fusenet'

# Delete an existing '<base>.pt' file before training starts.
# NOTE(review): presumably this keeps the trainer from resuming a previous run,
# and '.pt' is the suffix the trainer appends — confirm in FuseNetTrainer.fit.
stale_checkpoint_path = f'{checkpoint_file}.pt'
if os.path.isfile(stale_checkpoint_path):
    os.remove(stale_checkpoint_path)

In [None]:
# Run training on the train/test loaders and keep the returned result.
# NOTE(review): the exact semantics of early_stopping and print_every are
# defined by FuseNetTrainer.fit — confirm there.
res = trainer.fit(
    dl_train,
    dl_test,
    early_stopping=20,
    print_every=10,
    checkpoints=checkpoint_file,
)

--- EPOCH 1/400 ---
train_batch (Avg. Loss 0.077: 100%|██████████| 290/290 [06:05<00:00,  1.26s/it]
test_batch (Avg. Loss 0.067: 100%|██████████| 32/32 [00:14<00:00,  2.19it/s]
train_batch (Avg. Loss 0.066: 100%|██████████| 290/290 [06:35<00:00,  1.36s/it]
test_batch (Avg. Loss 0.061: 100%|██████████| 32/32 [00:10<00:00,  2.91it/s]
train_batch (Avg. Loss 0.060: 100%|██████████| 290/290 [05:27<00:00,  1.13s/it]
test_batch (Avg. Loss 0.066: 100%|██████████| 32/32 [00:09<00:00,  3.31it/s]
train_batch (Avg. Loss 0.059: 100%|██████████| 290/290 [05:21<00:00,  1.11s/it]
test_batch (Avg. Loss 0.063: 100%|██████████| 32/32 [00:09<00:00,  3.35it/s]
train_batch (Avg. Loss 0.057: 100%|██████████| 290/290 [05:31<00:00,  1.14s/it]
test_batch (Avg. Loss 0.063: 100%|██████████| 32/32 [00:16<00:00,  1.90it/s]
train_batch (Avg. Loss 0.057: 100%|██████████| 290/290 [05:54<00:00,  1.22s/it]
test_batch (Avg. Loss 0.058: 100%|██████████| 32/32 [00:15<00:00,  2.01it/s]
train_batch (0.031):  39%|███▉      | 