In [1]:
import os
import time
import torch
import sys
import torch.optim as optim
from torchvision import transforms as T
# Interactive matplotlib backend (nbagg).
# NOTE(review): nothing in this notebook currently draws a figure. The
# `notebook` backend is, as far as I know, only supported by the classic
# Notebook UI. If plotting is re-enabled, consider `%matplotlib widget`
# (ipympl) or `%matplotlib inline` instead — confirm for your Jupyter setup.
%matplotlib notebook

In [2]:
from models import SpecialFuseNetModel
from data_manager import rgbd_gradients_dataset, rgbd_gradients_dataloader
from train import FuseNetTrainer

In [3]:
# Resolve the NYUv2 data directory relative to the notebook's working directory.
# Each path component is joined separately so the OS-native separator is used
# throughout. Joining the literal 'data/nyuv2' produced a mixed-separator path
# such as 'C:\\...\\cs236781-project\\data/nyuv2' on Windows.
CWD         = os.getcwd()
DATASET_DIR = os.path.join(CWD, 'data', 'nyuv2')
print(DATASET_DIR)

C:\Users\tomav\Documents\GitHub\cs236781-project\data/nyuv2


In [4]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

Using device: cuda


In [5]:
# Standalone dataset instance over DATASET_DIR (transforms enabled); used
# later in the notebook to report how many samples were found.
rgbd_grads_ds = rgbd_gradients_dataset(
    root=DATASET_DIR,
    use_transforms=True,
)

In [6]:
# Train/test data loaders built from the same data directory, with the same
# use_transforms setting as rgbd_grads_ds.
dl_train, dl_test = rgbd_gradients_dataloader(
    root=DATASET_DIR,
    use_transforms=True,
)

In [7]:
# Sanity check: report how many samples the dataset discovered under DATASET_DIR.
# (Removed a commented-out call to `plot.rgbd_gradients_dataset_first_n`;
# `plot` is never imported in this notebook, so uncommenting it would raise NameError.)
print(f'Found {len(rgbd_grads_ds)} images in dataset folder.')

Found 6 images in dataset folder.


In [8]:
# Peek at one training batch and record each modality's per-sample shape by
# dropping the leading (presumably batch) dimension.
sample_batch = next(iter(dl_train))
rgb_size, depth_size, grads_size = (
    tuple(sample_batch[key].shape[1:]) for key in ('rgb', 'depth', 'x')
)

In [9]:
# Build the model from the per-modality input shapes taken from the sample batch.
fusenetmodel = SpecialFuseNetModel(
    rgb_size=rgb_size,
    depth_size=depth_size,
    grads_size=grads_size,
    device=device,
)

[I] - default optimizer set: SGD(lr=0.001,momentum=0.9,weight_decay=0.0005)
[I] - default scheduler set: StepSR(step_size=1000,gamma=0.1)


In [10]:
# Wrap the model in the trainer, passing the same device the model was built with.
trainer = FuseNetTrainer(
    model=fusenetmodel,
    device=device,
)

In [11]:
# Start this run from a clean slate by deleting any checkpoint left by a
# previous run.
# NOTE(review): presumably FuseNetTrainer.fit would otherwise pick up the old
# file — confirm.
checkpoint_file = 'checkpoints/special_fusenet'  # fit() writes '<checkpoint_file>.pt' (see log)
checkpoint_path = f'{checkpoint_file}.pt'

# Make sure the target directory exists so the checkpoint can be written on a
# fresh clone (unless the trainer already creates it — unverified).
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

try:
    # Remove directly rather than isfile() + remove(): no check-then-act race.
    os.remove(checkpoint_path)
except FileNotFoundError:
    pass  # no stale checkpoint to clear

In [None]:
# Training run configuration.
# Value passed as trainer.fit(early_stopping=...). Presumably this is the
# number of epochs without test-loss improvement before stopping — confirm.
# NOTE(review): the log shows 1000 epochs in total, so a value of 1000
# effectively disables early stopping. The test loss already rises around
# epoch 60 on this tiny dataset.
EARLY_STOPPING = 1000
PRINT_EVERY = 10  # passed as trainer.fit(print_every=...)

res = trainer.fit(
    dl_train,
    dl_test,
    early_stopping=EARLY_STOPPING,
    print_every=PRINT_EVERY,
    checkpoints=checkpoint_file,
)

--- EPOCH 1/1000 ---
train_batch (Avg. Loss 0.014: 100%|██████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.33s/it]
test_batch (Avg. Loss 0.014: 100%|███████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.16s/it]
train_batch (Avg. Loss 0.014: 100%|██████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.31s/it]
test_batch (Avg. Loss 0.014: 100%|███████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.12s/it]
[info] - Saved checkpoint checkpoints/special_fusenet.pt at epoch 2
train_batch (Avg. Loss 0.014: 100%|██████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.40s/it]
test_batch (Avg. Loss 0.014: 100%|███████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.17s/it]
[info] - Saved checkpoint checkpoints/special_fusenet.pt at epoch 3
train_batch (Avg. Loss 0.014: 100%|██████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.39s/it]
tes

train_batch (Avg. Loss 0.013: 100%|██████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.28s/it]
test_batch (Avg. Loss 0.013: 100%|███████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.10s/it]
[info] - Saved checkpoint checkpoints/special_fusenet.pt at epoch 28
train_batch (Avg. Loss 0.013: 100%|██████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.29s/it]
test_batch (Avg. Loss 0.013: 100%|███████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.08s/it]
train_batch (Avg. Loss 0.013: 100%|██████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.32s/it]
test_batch (Avg. Loss 0.013: 100%|███████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.10s/it]
[info] - Saved checkpoint checkpoints/special_fusenet.pt at epoch 30
--- EPOCH 31/1000 ---
train_batch (Avg. Loss 0.013: 100%|██████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.31s/it]


test_batch (Avg. Loss 0.013: 100%|███████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.15s/it]
[info] - Saved checkpoint checkpoints/special_fusenet.pt at epoch 58
train_batch (Avg. Loss 0.013: 100%|██████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.34s/it]
test_batch (Avg. Loss 0.012: 100%|███████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.20s/it]
[info] - Saved checkpoint checkpoints/special_fusenet.pt at epoch 59
train_batch (Avg. Loss 0.013: 100%|██████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.40s/it]
test_batch (Avg. Loss 0.017: 100%|███████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.29s/it]
--- EPOCH 61/1000 ---
train_batch (Avg. Loss 0.013: 100%|██████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.47s/it]
test_batch (Avg. Loss 0.026: 100%|███████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.15s/it]
