In [None]:
from glob import glob
import numpy as np
from concurrent.futures import ThreadPoolExecutor # To share lru_cache

import sys
sys.path.insert(0, '../ocamcalib_undistort')
sys.path.insert(0, '../')
from ocamcamera import OcamCamera

import cv2
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
from mpl_toolkits.mplot3d import Axes3D
matplotlib.rcParams['image.cmap'] = 'gray'
plt.rcParams['figure.figsize'] = (8, 6)
%matplotlib inline

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import numpy as np
from models import OmniMVS
from models import SphericalSweeping

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from os.path import join
import cv2

## Generate filename list
# with open('omnithings_train.txt', 'w') as f:
#     for i in range(1, 4097):
#         f.write(f'{i:05}.png\n')


class OmniStereoDataset(Dataset):
    """Omnidirectional Stereo Dataset.
    http://cvlab.hanyang.ac.kr/project/omnistereo/

    Each sample is a dict holding one image per camera (keys ``'cam1'`` ..
    ``'cam4'``) and the inverse depth map (key ``'idepth'``).

    Args:
        root_dir: Dataset root. It must contain one sub-folder per camera,
            the ``depth_train_640`` folder and the ``ocam1.txt`` ..
            ``ocam4.txt`` calibration files.
        filename_txt: Text file listing one image filename per line.
        transform: Optional callable applied to the sample dict.
        fov: Field of view forwarded to ``OcamCamera``.
    """

    def __init__(self, root_dir, filename_txt, transform=None, fov=220):
        self.root_dir = root_dir
        self.transform = transform

        # load filenames
        self.filenames = self._read_filenames(filename_txt)

        # folder name
        self.cam_list = ['cam1', 'cam2', 'cam3', 'cam4']
        self.depth_folder = 'depth_train_640'

        # load ocam calibration data and generate valid image
        self.ocams = []
        self.valids = []
        for cam in self.cam_list:
            ocam_file = join(root_dir, f'o{cam}.txt')
            self.ocams.append(OcamCamera(ocam_file, fov, show_flag=False))
            self.valids.append(self.ocams[-1].valid_area())

    @staticmethod
    def _read_filenames(filename_txt):
        """Return the non-empty, whitespace-stripped lines of ``filename_txt``.

        Blank lines are skipped so that an empty list file yields an empty
        dataset instead of a single bogus ``''`` entry.
        """
        with open(filename_txt) as f:
            stripped = (line.strip() for line in f)
            return [name for name in stripped if name]

    def __len__(self):
        """Number of samples (one per listed filename)."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load the camera images and the inverse depth map of sample ``idx``."""
        sample = {}

        filename = self.filenames[idx]
        # load images (pixels outside each camera's valid area are zeroed)
        for i, cam in enumerate(self.cam_list):
            img_path = join(self.root_dir, cam, filename)
            sample[cam] = load_image(img_path, valid=self.valids[i])
        # load inverse depth
        depth_path = join(self.root_dir, self.depth_folder, filename)
        sample['idepth'] = load_invdepth(depth_path)

        if self.transform:
            sample = self.transform(sample)

        return sample
    
    
def load_invdepth(filename, min_depth=55):
    '''
    Load an inverse depth map stored as an image file.

    min_depth in [cm]

    The raw pixel values are read with their stored bit depth, scaled by
    1 / (100 * min_depth * 655), offset by float32 epsilon (so the result is
    never exactly zero) and finally multiplied by 100 to go from 1/cm to 1/m.

    Raises:
        FileNotFoundError: if the file is missing or cannot be decoded.
    '''
    invd_value = cv2.imread(filename, cv2.IMREAD_ANYDEPTH)
    if invd_value is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f'Could not read inverse depth image: {filename}')
    invdepth = (invd_value/100.0)/(min_depth*655)+np.finfo(np.float32).eps
    invdepth *= 100 # unit conversion from cm to m
    return invdepth


def load_image(filename, gray=True, valid=None):
    """Read an image and optionally black out its invalid area.

    Args:
        filename: Path of the image file.
        gray: If True, convert the BGR image to single-channel grayscale;
            otherwise convert it to RGB channel order.
        valid: Optional mask indexable like the image; every pixel where the
            mask equals 0 is set to 0 in the returned image.

    Returns:
        The converted (and optionally masked) image array.

    Raises:
        FileNotFoundError: if the file is missing or cannot be decoded.
    """
    img = cv2.imread(filename)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f'Could not read image: {filename}')

    if gray:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    else:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    if valid is not None:
        img[valid==0] = 0
    return img




In [None]:
import argparse

# Command-line style configuration for the training run. The notebook feeds
# a fixed argument string to the parser at the bottom of this cell.
parser = argparse.ArgumentParser(
    description='Training for OmniMVS',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# --- dataset and checkpoint locations ---
parser.add_argument('root_dir', metavar='DIR', help='path to dataset')
parser.add_argument(
    '-t', '--train-list',
    type=str,
    default='../datasets/omnithings/omnithings_train.txt',
    help='Text file includes filenames for training',
)
parser.add_argument(
    '--epochs', type=int, default=30, metavar='N', help='total epochs',
)
parser.add_argument(
    '--pretrained', default=None, metavar='PATH',
    help='path to pre-trained model',
)

# --- model / data-loading settings ---
parser.add_argument(
    '-b', '--batch-size', type=int, default=1, metavar='N', help='mini-batch size',
)
parser.add_argument(
    '--ndisp', type=int, default=192, help='number of disparity',
)
parser.add_argument(
    '--min_depth', type=float, default=0.55, help='minimum depth in m',
)
parser.add_argument(
    '--output_width', type=int, default=640, help='output depth width',
)
parser.add_argument(
    '--output_height', type=int, default=320, help='output depth height',
)
parser.add_argument(
    '-j', '--workers', type=int, default=6, metavar='N',
    help='number of data loading workers',
)

# --- optimizer settings ---
parser.add_argument(
    '--lr', '--learning-rate', type=float, default=3e-3, metavar='LR',
    help='initial learning rate',
)
parser.add_argument(
    '--momentum', type=float, default=0.9, metavar='M', help='momentum for sgd',
)

# Parse the notebook's fixed arguments instead of sys.argv.
args = parser.parse_args(
    '../datasets/omnithings -t ./omnithings_train.txt --ndisp 28'.split()
)


In [None]:
# Display the parsed arguments to double-check the run configuration.
args

In [None]:
# Use the GPU whenever CUDA is available, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
# device = torch.device('cpu')
if device.type != 'cpu':
    # Only touch the cuDNN benchmark flag when actually running on a GPU.
    cudnn.benchmark = True
print("device:", device)

In [None]:
from torchvision import transforms

class ToTensor(object):
    """Turn the numpy arrays of a sample dict into torch tensors.

    The inverse depth entry becomes a float tensor via ``torch.from_numpy``;
    each camera image goes through torchvision's ``ToTensor`` transform.
    The sample dict is modified in place and returned.
    """

    def __init__(self):
        self.cam_list = ['cam1', 'cam2', 'cam3', 'cam4']
        self.depth = 'idepth'
        self.ToTensor = transforms.ToTensor()

    def __call__(self, sample):
        # dataloader deal with conversion for others
        depth_key = self.depth
        depth_array = sample[depth_key]
        sample[depth_key] = torch.from_numpy(depth_array).float()

        to_tensor = self.ToTensor
        for cam_key in self.cam_list:
            sample[cam_key] = to_tensor(sample[cam_key])
        return sample


class Normalize(object):
    """Normalize every camera image of a sample channel-wise, in place.

    Each channel ``c`` of each camera tensor becomes ``(x - mean[c]) / std[c]``.
    Channels are paired with ``mean``/``std`` by ``zip``, so a tensor with
    fewer channels than statistics only uses the leading entries.

    Args:
        mean: Per-channel means.
        std: Per-channel standard deviations.
    """

    def __init__(self, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        # Copy into per-instance lists: with mutable list defaults stored
        # directly, editing one instance's statistics would silently change
        # the defaults of every instance created afterwards.
        self.mean = list(mean)
        self.std = list(std)
        self.cam_list = ['cam1', 'cam2', 'cam3', 'cam4']

    def __call__(self, sample):
        for cam in self.cam_list:
            # Iterating a tensor yields views of its leading dimension, so the
            # in-place ops below modify sample[cam] itself.
            for t, m, s in zip(sample[cam], self.mean, self.std):
                t.sub_(m).div_(s)
        return sample
    
# Converter used later in the notebook to turn tensors back into PIL images for display.
ToPIL = transforms.ToPILImage()
# Training pipeline: convert the sample to tensors, then normalize the camera images.
train_transform = transforms.Compose([ToTensor(), Normalize()])

In [None]:
# Build the training set from the parsed arguments.
root_dir = args.root_dir
filename_txt = args.train_list
trainset = OmniStereoDataset(
    root_dir=root_dir,
    filename_txt=filename_txt,
    transform=train_transform,
)
print(f'{len(trainset)} samples were found.')

In [None]:
# Shuffled mini-batches for training, loaded by background worker processes.
train_loader = DataLoader(
    trainset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
)

In [None]:
# Fetch one batch for a visual sanity check. The built-in next() is used
# because the DataLoader iterator's .next() alias no longer exists in
# current PyTorch releases.
batch = next(iter(train_loader))
tensor = batch['cam1'][0]
# 0.5 + 0.5 * x undoes the default Normalize (mean 0.5, std 0.5) before display.
plt.imshow(ToPIL(0.5+0.5*tensor))

In [None]:
# Show the ground-truth inverse depth map of the first sample in the batch.
invd = batch['idepth'][0]
plt.imshow(invd.numpy())

# Build model and run

In [None]:
class InvDepthConverter(object):
    """Linear mapping between inverse depth values and sweep indices.

    ``invd_0`` corresponds to index 0 and ``invd_max`` to index ``ndisp - 1``.

    Args:
        ndisp: Number of discrete inverse depth levels.
        invd_0: Inverse depth assigned to index 0.
        invd_max: Inverse depth assigned to index ``ndisp - 1``.
    """

    def __init__(self, ndisp, invd_0, invd_max):
        self._ndisp = ndisp
        self._invd_0 = invd_0
        self._invd_max = invd_max

    def invdepth_to_index(self, idepth):
        """Convert an inverse depth tensor to rounded (integer-valued) indices."""
        invd_idx = (self._ndisp-1)*(idepth - self._invd_0)/(self._invd_max - self._invd_0)
        # Q: why round?
        # NOTE(review): rounding snaps the target onto the discrete levels;
        # confirm whether a fractional target is preferable for the L1 loss.
        invd_idx = torch.round(invd_idx)
        return invd_idx

    def index_to_invdepth(self, invd_idx):
        """Convert (possibly fractional) indices back to inverse depth."""
        # Inverse of invdepth_to_index (without the rounding); the offset is
        # the inverse depth of index 0.
        idepth = self._invd_0 + invd_idx*(self._invd_max - self._invd_0)/(self._ndisp-1)
        return idepth


In [None]:
# Output resolution shared by the sweeping module and the network.
output_size = dict(h=args.output_height, w=args.output_width)

sweep = SphericalSweeping(root_dir, **output_size)
model = OmniMVS(sweep, args.ndisp, args.min_depth, **output_size)

# The first and last inverse depth levels of the model define the range of
# the index <-> inverse depth conversion.
invd_0, invd_max = model.inv_depths[0], model.inv_depths[-1]
converter = InvDepthConverter(args.ndisp, invd_0, invd_max)

model = model.to(device)
start_epoch = 0

# cache
# Request the grids for every camera at every second depth level on a thread
# pool; presumably this warms a cache inside sweep.get_grid (verify there).
num_cam = 4
pool = ThreadPoolExecutor(5)
futures = [
    pool.submit(sweep.get_grid, cam_idx, depth)
    for cam_idx in range(num_cam)
    for depth in model.depths[::2]
]

In [None]:
# setup solver scheduler
print('=> setting optimizer')
optimizer = torch.optim.SGD(model.parameters(),lr=args.lr, momentum=args.momentum)
# optimizer = torch.optim.Adam(model.parameters(),lr=3e-4)

print('=> setting scheduler')
# Multiply the learning rate by 0.1 every 20 scheduler steps (one step per epoch).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

if args.pretrained:
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU can also be resumed on a CPU-only machine.
    checkpoint = torch.load(args.pretrained, map_location=device)
    print("=> using pre-trained weights")
    model.load_state_dict(checkpoint['state_dict'])
    start_epoch = checkpoint['epoch']
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    print("=> Resume training from epoch {}".format(start_epoch))

print('=> wait for a while until all tasks in pool are finished')
# shutdown() blocks until every task submitted to the pool has completed.
pool.shutdown()
print('=> Done!')

In [None]:
# # cache
# with torch.no_grad():
#     for key in batch.keys():
#         batch[key] = batch[key].to(device)
# #     out = model(batch)
# # #     del out, batch # save memory

# Training

In [None]:
# # Single batch overfitting
# from tqdm.notebook import tqdm
# from collections import OrderedDict

# batch = iter(train_loader).next()

# losses = []
# pbar = tqdm(range(1000))
# for it in pbar:
#     # to cuda
#     for key in batch.keys():
#         batch[key] = batch[key].to(device)
#     pred = model(batch)

#     gt_idepth = batch['idepth']
#     # Loss function  
#     gt_invindex = converter.invdepth_to_index(gt_idepth)
#     loss = nn.L1Loss()(pred, gt_invindex)
#     losses.append(loss.item())

#     # update parameters
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()

#     # update progress bar
#     display = OrderedDict(it=f"{it:>2}",loss=f"{losses[-1]:.4f}")
#     pbar.set_postfix(display)
    
# plt.title('Loss (log)')
# plt.plot(losses)
# plt.yscale('log')
# plt.show()

In [None]:
from tqdm.notebook import tqdm
from collections import OrderedDict

# Mean absolute error between predicted and ground-truth sweep indices.
l1_criterion = nn.L1Loss()

for epoch in range(start_epoch, args.epochs):
    model.train()
    losses = []
    pbar = tqdm(train_loader)
    for idx, batch in enumerate(pbar):
        # to cuda
        batch = {key: value.to(device) for key, value in batch.items()}

        # Forward pass; the target is the ground-truth inverse depth
        # expressed as a (rounded) sweep index.
        pred = model(batch)
        gt_idepth = batch['idepth']
        gt_invindex = converter.invdepth_to_index(gt_idepth)

        # Loss function
        loss = l1_criterion(pred, gt_invindex)
        losses.append(loss.item())

        # update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update progress bar
        pbar.set_postfix(OrderedDict(epoch=f"{epoch:>2}", loss=f"{losses[-1]:.4f}"))

    # End of one epoch
    scheduler.step()
    ave_loss = sum(losses)/len(losses)
    print(f"Epoch:{epoch}, Loss average:{ave_loss:.4f}")

    # Save everything needed to resume: 'epoch' is the next epoch to run.
    save_data = {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'ave_loss' : ave_loss,
    }
    torch.save(save_data, f'checkpoints_{epoch}.pth')

# Inference

In [None]:
# Keep the iterator in its own cell so the next cell can be re-run to step through batches.
loader_iter = iter(train_loader)

In [None]:
# Built-in next() instead of the iterator's .next() alias, which no longer
# exists in current PyTorch releases.
batch = next(loader_iter)
model.eval()
with torch.no_grad():
    for key in batch.keys():
        batch[key] = batch[key].to(device)
    pred = model(batch)
    gt_idepth = batch['idepth']
    gt_invd_idx = converter.invdepth_to_index(gt_idepth)
    # Absolute error measured in sweep-index units.
    error = torch.abs(pred-gt_invd_idx)

In [None]:
# Inspect the raw prediction tensor.
pred

In [None]:
# Show the prediction as an image; dividing by ndisp presumably brings the index values into [0, 1) - TODO confirm.
ToPIL(pred.cpu()/args.ndisp)

In [None]:
# Show the ground-truth index map with the same scaling, for comparison with the prediction.
ToPIL(gt_invd_idx.cpu()/args.ndisp)