In [8]:
import os
import random
import argparse
import numpy as np
import pandas as pd
import json
from collections import OrderedDict
from PIL import Image
import imageio
from tqdm import tqdm

import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader

import intrinsics_utils
from loss_fn import DMPLoss
from depth_prediction_net import DispNetS
from object_motion_net import MotionVectorNet

# Size handed to transforms.Resize for the training frames; torchvision reads a pair as (height, width).
rsize_factor = (128, 416)

class DepthMotionDataset(Dataset):
    """Dataset of consecutive frame pairs read from ``<root_dir>/images/taichung/``.

    Item ``idx`` is the list ``[frame idx, frame idx + 1]`` taken from the
    sorted directory listing, restricted to the fixed window ``[15000:16001]``.

    Args:
        mode: Accepted for interface compatibility; this class does not read it.
        transform: Optional callable applied independently to each frame of a
            pair. When it is falsy the decoded PIL images are returned as-is.
        root_dir: Directory that contains the ``images/taichung/`` folder.
    """

    def __init__(self, mode='train', transform=None, root_dir='./',):
        # Fixed window of the sorted listing: the slice keeps at most 1001 frames.
        self.image_list = sorted(os.listdir(f'{root_dir}/images/taichung/'))[15000:16001]
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of consecutive pairs (frames - 1, never negative)."""
        # An empty window used to yield -1, which makes len() raise ValueError.
        return max(len(self.image_list) - 1, 0)

    def _load_frame(self, file_name):
        """Decode one frame and apply ``self.transform`` when one is configured."""
        with Image.open(f'{self.root_dir}/images/taichung/' + file_name) as img:
            # Force decoding while the file is open so the pixel data stays
            # usable after the context manager has closed the file handle.
            img.load()
        if self.transform:
            return self.transform(img)
        return img

    def __getitem__(self, idx):
        """Return ``[frame idx, frame idx + 1]``.

        Raises:
            IndexError: if ``idx`` does not address an existing pair.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        num_pairs = len(self)
        if idx < 0:
            # Python-style negative indexing over pairs; previously idx == -1
            # silently paired the last frame with the first one.
            idx += num_pairs
        if not 0 <= idx < num_pairs:
            raise IndexError(f'pair index out of range for {num_pairs} pairs')

        # Previously sample_a / sample_b were unbound when no transform was set.
        sample_a = self._load_frame(self.image_list[idx])
        sample_b = self._load_frame(self.image_list[idx + 1])
        return [sample_a, sample_b]

In [9]:
# --- Reproducibility ---------------------------------------------------------
# Seed every RNG this notebook imports; NumPy's global generator was previously
# left unseeded while `random` and `torch` were seeded.
seed = 100
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
cudnn.deterministic = True
# Keep the cuDNN autotuner off so algorithm selection does not vary between runs.
cudnn.benchmark = False

# Directory the training loop writes model checkpoints into.
PATH = './checkpoints/'

# --- Loss configuration ------------------------------------------------------
# Per-term weights handed to DMPLoss.
default_loss_weights = {
    'rgb_consistency': 1.0,
    'ssim': 3.0,
    'depth_consistency': 0.05,
    'depth_smoothing': 0.05,
    'rotation_cycle_consistency': 1e-3,
    'translation_cycle_consistency': 5e-2,
    'depth_variance': 0.0,
    'motion_smoothing': 1.0,
    'motion_drift': 0.2,
}

# --- Training hyperparameters ------------------------------------------------
batch_size = 16
# Length of the ramp applied to the residual translation field in the training
# loop: the field is scaled by clamp(2 * step / motion_field_burning_steps - 1, 0, 1).
motion_field_burning_steps = 20000
epochs = 30 #90
# Both values are forwarded to MotionVectorNet.
intrinsics_mat = None
use_intrinsics = False
# NOTE(review): not read anywhere in the visible notebook - confirm before removing.
delete_file = True
# Only used to derive `base_step`; the visible loop steps the optimizer on every batch.
accumulate_grad_batches = 4
# Only referenced by the commented-out ReduceLROnPlateau scheduler call.
metrics = 0

# --- Data pipeline -----------------------------------------------------------
# Every frame is resized to `rsize_factor` and converted to a tensor.
frame_transform = transforms.Compose([
    transforms.Resize(size=rsize_factor),
    transforms.ToTensor(),
])
train_dataset = DepthMotionDataset(mode='train', transform=frame_transform, root_dir='./')
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
    drop_last=False,
    sampler=None,
    pin_memory=False,
)

# --- Models ------------------------------------------------------------------
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

depth_net = DispNetS().to(device)
#depth_net = torch.nn.DataParallel(depth_net)
object_motion_net = MotionVectorNet(
    auto_mask=True,
    intrinsics=use_intrinsics,
    intrinsics_mat=intrinsics_mat,
).to(device)
#object_motion_net = torch.nn.DataParallel(object_motion_net)

# --- Loss, step bookkeeping and optimizer --------------------------------------
loss_func = DMPLoss(default_loss_weights)
train_batches = len(train_loader)
base_step = train_batches // accumulate_grad_batches

# A single Adam optimizer updates the parameters of both networks.
trainable_params = [*depth_net.parameters(), *object_motion_net.parameters()]
optimizer = optim.Adam(trainable_params, lr=1e-4, weight_decay=1e-4)
#scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.8, patience=5)

# The per-epoch image and checkpoint writes fail if their directories are
# missing, which would only surface after a full epoch of work.
os.makedirs('./train_log', exist_ok=True)
os.makedirs(PATH, exist_ok=True)

for epoch in range(epochs):
    # Nothing in the loop switches the networks to eval mode, so setting
    # training mode once per epoch is sufficient.
    depth_net.train()
    object_motion_net.train()

    for batch_idx, rgb_seq_images in enumerate(tqdm(train_loader)):
        # Number of optimizer updates performed so far. The optimizer steps on
        # every batch, so the burn-in schedule has to count batches; the old
        # `base_step * epoch` was constant within an epoch and divided by
        # `accumulate_grad_batches` although no gradient accumulation happens.
        global_step = epoch * train_batches + batch_idx

        frame_a = rgb_seq_images[0].to(device)
        frame_b = rgb_seq_images[1].to(device)
        endpoints = {}
        optimizer.zero_grad()

        # One depth forward pass over both frames, then split back per frame.
        rgb_images = torch.cat((frame_a, frame_b), dim=0)
        depth_images = depth_net(rgb_images)
        depth_seq_images = torch.split(depth_images, depth_images.shape[0] // 2, dim=0)
        endpoints['predicted_depth'] = depth_seq_images
        endpoints['rgb'] = [frame_a, frame_b]

        # Per-frame RGB + depth features, paired in both temporal orders so the
        # motion network sees (a -> b) and (b -> a) in one batch.
        motion_features = [torch.cat((endpoints['rgb'][0], endpoints['predicted_depth'][0]), dim=1),
                           torch.cat((endpoints['rgb'][1], endpoints['predicted_depth'][1]), dim=1)]
        motion_features_stack = torch.cat(motion_features, dim=0)
        flipped_motion_features_stack = torch.cat(motion_features[::-1], dim=0)
        pairs = torch.cat([motion_features_stack, flipped_motion_features_stack], dim=1)

        # Use a loop-local name for the predicted intrinsics so the config
        # global `intrinsics_mat` is not overwritten with a tensor.
        rot, trans, residual_translation, pred_intrinsics = object_motion_net(pairs)

        if motion_field_burning_steps > 0.0:
            # Linear ramp clamp(2 * step / burn_in - 1, 0, 1): zero until
            # step reaches burn_in / 2, one from step == burn_in onwards.
            burnin_scale = min(max(2.0 * global_step / motion_field_burning_steps - 1.0, 0.0), 1.0)
            # Out-of-place multiply: avoids mutating a network output in place.
            residual_translation = residual_translation * burnin_scale

        endpoints['residual_translation'] = torch.split(residual_translation, residual_translation.shape[0] // 2, dim=0)
        endpoints['background_translation'] = torch.split(trans, trans.shape[0] // 2, dim=0)
        endpoints['rotation'] = torch.split(rot, rot.shape[0] // 2, dim=0)
        # Average the intrinsics predicted for the two pair orderings and share
        # the result (and its inverse) between both directions.
        pred_intrinsics = 0.5 * sum(torch.split(pred_intrinsics, pred_intrinsics.shape[0] // 2, dim=0))
        endpoints['intrinsics_mat'] = [pred_intrinsics] * 2
        endpoints['intrinsics_mat_inv'] = [intrinsics_utils.invert_intrinsics_matrix(pred_intrinsics)] * 2

        loss_val = loss_func(endpoints)
        loss_val.backward()
        optimizer.step()
    #scheduler.step(metrics)

    # The reported loss is that of the last batch of the epoch, not an epoch mean.
    print(f'Epoch : {epoch + 1:02d}/{epochs}, loss : {loss_val.item():.03f}')
    # Snapshot of the last sample's residual translation from the final batch.
    # NOTE(review): values outside [0, 1] do not map cleanly onto uint8 - confirm
    # whether the field should be normalised before saving.
    imageio.imwrite(f'./train_log/{epoch + 1:02d}.png',
                    (endpoints['residual_translation'][1][-1].cpu().detach().numpy() * 255).astype(np.uint8))
    torch.save(depth_net.state_dict(), PATH + 'depth_model.ckpt')
    torch.save(object_motion_net.state_dict(), PATH + 'object_motion_model.ckpt')

100%|███████████████████████████████████████████████████████████████████████████████████| 63/63 [01:35<00:00,  1.52s/it]


Epoch : 01/30, loss : 0.005


100%|███████████████████████████████████████████████████████████████████████████████████| 63/63 [01:36<00:00,  1.53s/it]


Epoch : 02/30, loss : 0.005


100%|███████████████████████████████████████████████████████████████████████████████████| 63/63 [01:37<00:00,  1.55s/it]


Epoch : 03/30, loss : 0.006


100%|███████████████████████████████████████████████████████████████████████████████████| 63/63 [01:43<00:00,  1.65s/it]


Epoch : 04/30, loss : 0.024


100%|███████████████████████████████████████████████████████████████████████████████████| 63/63 [01:44<00:00,  1.66s/it]


Epoch : 05/30, loss : 0.011


100%|███████████████████████████████████████████████████████████████████████████████████| 63/63 [01:44<00:00,  1.66s/it]


Epoch : 06/30, loss : 0.010


100%|███████████████████████████████████████████████████████████████████████████████████| 63/63 [01:44<00:00,  1.67s/it]


Epoch : 07/30, loss : 0.008


100%|███████████████████████████████████████████████████████████████████████████████████| 63/63 [01:45<00:00,  1.67s/it]


Epoch : 08/30, loss : 0.008


100%|███████████████████████████████████████████████████████████████████████████████████| 63/63 [01:44<00:00,  1.66s/it]


Epoch : 09/30, loss : 0.009


100%|███████████████████████████████████████████████████████████████████████████████████| 63/63 [01:44<00:00,  1.66s/it]


Epoch : 10/30, loss : 0.008


100%|███████████████████████████████████████████████████████████████████████████████████| 63/63 [01:44<00:00,  1.66s/it]


Epoch : 11/30, loss : 0.007


100%|███████████████████████████████████████████████████████████████████████████████████| 63/63 [01:44<00:00,  1.66s/it]


Epoch : 12/30, loss : 0.019


100%|███████████████████████████████████████████████████████████████████████████████████| 63/63 [01:44<00:00,  1.66s/it]


Epoch : 13/30, loss : 0.008


 32%|██████████████████████████▎                                                        | 20/63 [00:35<01:15,  1.76s/it]


KeyboardInterrupt: 

In [17]:
# Index 0 along the last axis for sample 1 of the first residual-translation split.
endpoints['residual_translation'][0][1][:, :, 0]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [-0., -0., -0.,  ..., -0., -0., -0.],
        [-0., -0., -0.,  ..., -0., -0., -0.],
        [-0., -0., -0.,  ..., -0., -0., -0.]], device='cuda:0',
       grad_fn=<SelectBackward>)

In [19]:
# Size of the first background-translation split.
endpoints['background_translation'][0].size()

torch.Size([16, 3])

In [None]:
import math
def cal(h, w, pad, ker, stri, dila=1):
    """Return the ``(height, width)`` left after a strided sliding window.

    Each axis is computed as
    ``floor((size + 2 * pad - dila * (ker - 1) - 1) / stri + 1)``,
    which matches the output-size formula documented for ``torch.nn.Conv2d``.

    Args:
        h: Input height.
        w: Input width.
        pad: Padding added to both sides of each axis.
        ker: Window (kernel) size.
        stri: Stride.
        dila: Dilation (spacing between window elements).

    Returns:
        Tuple ``(out_h, out_w)``.
    """
    def _axis_size(size):
        # Same arithmetic for both axes; only the input extent differs.
        numerator = size + 2 * pad - dila * (ker - 1) - 1
        return math.floor((numerator / stri) + 1)

    return tuple(_axis_size(extent) for extent in (h, w))

In [None]:
# Output size for a 32 x 104 input with a 1-wide window, stride 2 and no padding.
cal(h=32, w=104, pad=0, ker=1, stri=2)

In [None]:
def transcal(h, w, pad, ker, stri, dila=1):
    """Return the ``(height, width)`` produced by a transposed convolution.

    Each axis follows the output-size formula documented for
    ``torch.nn.ConvTranspose2d`` with ``output_padding == 0``:
    ``(size - 1) * stri - 2 * pad + dila * (ker - 1) + 1``.

    Args:
        h: Input height.
        w: Input width.
        pad: Padding of the transposed convolution.
        ker: Kernel size.
        stri: Stride.
        dila: Dilation. It used to be accepted but ignored, so results were
            only correct for ``dila == 1``; it is now part of the formula.

    Returns:
        Tuple ``(out_h, out_w)``.
    """
    def _axis_size(size):
        # With dila == 1 this reduces to the previous (size - 1) * stri - 2 * pad + ker.
        return (size - 1) * stri - 2 * pad + dila * (ker - 1) + 1

    return _axis_size(h), _axis_size(w)

In [None]:
# Transposed-convolution output size for a 32 x 104 input: 1-wide kernel, stride 2, padding 1.
transcal(h=32, w=104, pad=1, ker=1, stri=2)