In [None]:
# import open3d as o3d

In [None]:
from glob import glob
import numpy as np
from concurrent.futures import ThreadPoolExecutor # To share lru_cache

import sys
sys.path.insert(0, '../')
from os.path import join
from ocamcamera import OcamCamera

import argparse
import random
from datetime import datetime
import json
import os
import cv2
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
from mpl_toolkits.mplot3d import Axes3D
matplotlib.rcParams['image.cmap'] = 'gray'
plt.rcParams['figure.figsize'] = (8, 6)
%matplotlib inline

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

import numpy as np
from models import OmniMVS
from models import SphericalSweeping
from dataloader import OmniStereoDataset
from dataloader import load_image, load_invdepth
from torchvision import transforms
from dataloader.custom_transforms import Resize, ToTensor, Normalize
from utils import InvDepthConverter, evaluation_metrics

In [None]:
# Argument parser describing every tunable of the OmniMVS training run.
parser = argparse.ArgumentParser(
    description='Training for OmniMVS',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# Dataset location, run length and optional checkpoint to resume from.
parser.add_argument('root_dir', metavar='DIR', help='path to dataset')
parser.add_argument(
    '-t', '--train-list',
    default='../datasets/omnithings/omnithings_train.txt',
    type=str,
    help='Text file includes filenames for training',
)
parser.add_argument('--epochs', default=30, type=int, metavar='N', help='total epochs')
parser.add_argument('--pretrained', default=None, metavar='PATH', help='path to pre-trained model')
parser.add_argument('-b', '--batch-size', default=1, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--min_depth', type=float, default=0.55, help='minimum depth in m')

# Resolution / disparity defaults come in two flavours; flip the switch to
# use the paper setting instead of the light-weight one.
use_paper_setting = False
resolution_flags = (
    ('--ndisp', 'number of disparity'),
    ('--input_width', 'input image width'),
    ('--input_height', 'input image height'),
    ('--output_width', 'output depth width'),
    ('--output_height', 'output depth height'),
)
resolution_defaults = {
    True: (192, 800, 768, 640, 320),   # Paper setting
    False: (64, 500, 480, 512, 256),   # Light weight
}
for (flag, help_text), default in zip(resolution_flags, resolution_defaults[use_paper_setting]):
    parser.add_argument(flag, type=int, default=default, help=help_text)

# Data loading, optimisation and logging options.
parser.add_argument('-j', '--workers', default=6, type=int, metavar='J',
                    help='number of data loading workers')
parser.add_argument('--lr', '--learning-rate', default=3e-3, type=float, metavar='LR',
                    help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum for sgd')
parser.add_argument('--arch', default='omni_small', type=str,
                    help='architecture name for log folder')
parser.add_argument('--log-interval', type=int, default=1, metavar='L',
                    help='tensorboard log interval')
                   

# a, b = 800, 768
# for it in range(b+1)[::-1]:
#     new_a = it*a/b
#     if new_a == int(new_a):
#         print(f'{new_a:.0f}', it)
# print()
# a, b = 640, 320
# for it in range(b+1)[::-1]:
#     new_a = it*a/b
#     if new_a == int(new_a) and min(new_a,it)%32==0:
#         print(f'{new_a:.0f}', it)

# The notebook has no real command line, so the argument vector is assembled
# here from string fragments and handed to the parser.
root_dir = '../datasets/omnithings'
file_list = '-t ./omnithings_train.txt'
# model_params = "--input_width 500 --input_height 480 --output_width 512 --output_height 256 --ndisp 64"
pretrained = "--pretrained ../lowlr_smalldisp_0111-2204/checkpoints_20.pth"
overrides = '--lr 1e-3 --ndisp 48'
cli_tokens = [
    token
    for fragment in (root_dir, file_list, pretrained, overrides)
    for token in fragment.split()   # whitespace tokenisation, as a shell would do
]
args = parser.parse_args(cli_tokens)

# Generate filename list
# Each entry: (output file, zero-pad width of the frame number, index ranges).
# The ranges are half-open, so e.g. range(1, 4097) covers frames 1..4096.
split_specs = (
    ('omnithings_train.txt', 5, (range(1, 4097), range(5121, 8241))),
    ('omnithings_val.txt', 5, (range(8241, 10240+1),)),
    ('omnihouse_val.txt', 4, (range(1, 2560+1),)),
)
for list_name, digits, index_ranges in split_specs:
    with open(list_name, 'w') as f:
        f.writelines(
            f'{frame:0{digits}}.png\n'
            for span in index_ranges
            for frame in span
        )
        

In [None]:
args

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
# device = torch.device('cpu')
on_accelerator = device.type != 'cpu'
if on_accelerator:
    # Enable the cuDNN benchmark mode only when not running on the CPU.
    cudnn.benchmark = True
print("device:", device)

# Setup model

In [None]:
# Build the sweeping module and the network; both share the output resolution.
output_hw = dict(h=args.output_height, w=args.output_width)
sweep = SphericalSweeping(args.root_dir, **output_hw)
model = OmniMVS(sweep, args.ndisp, args.min_depth, **output_hw)

# The first and last of the model's inverse-depth samples parameterise the
# inverse-depth <-> index converter used for the loss and the evaluation.
invd_0, invd_max = model.inv_depths[0], model.inv_depths[-1]
converter = InvDepthConverter(args.ndisp, invd_0, invd_max)

model = model.to(device)
start_epoch = 0

# cache
# Submit sweep.get_grid(camera, depth) for every camera and every second
# depth to a thread pool; the pool is waited on in a later cell.
# NOTE(review): presumably this pre-fills an lru_cache inside get_grid (see
# the comment on the ThreadPoolExecutor import) — confirm in SphericalSweeping.
num_cam = 4
pool = ThreadPoolExecutor(5)
futures = [
    pool.submit(sweep.get_grid, cam_idx, depth)
    for cam_idx in range(num_cam)
    for depth in model.depths[::2]
]

In [None]:
# setup solver scheduler
print('=> setting optimizer')
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
# optimizer = torch.optim.Adam(model.parameters(),lr=3e-4)

print('=> setting scheduler')
# Multiply the learning rate by 0.1 every 20 scheduler steps (one step per epoch).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

if args.pretrained:
    # map_location remaps the stored tensors onto the device in use, so a
    # checkpoint written on a GPU can still be restored on a CPU-only machine.
    checkpoint = torch.load(args.pretrained, map_location=device)

    # The checkpoint records the geometry it was trained with; report any
    # setting that differs from the model just built (best effort: a
    # mismatch is printed, not raised).
    param_check = {
        'ndisp' : model.ndisp,
        'min_depth' : model.min_depth,
        'output_width' : model.w,
        'output_height' : model.h,
    }
    for key, val in param_check.items():
        if checkpoint[key] != val:
            print(f'Error! Key:{key} is not the same as the checkpoints')

    print("=> using pre-trained weights")
    model.load_state_dict(checkpoint['state_dict'])
    # The training loop stores 'epoch' after that epoch has finished and the
    # scheduler has been stepped, so training must resume with the next one.
    start_epoch = checkpoint['epoch'] + 1
    # NOTE(review): restoring the optimizer/scheduler state also restores the
    # learning rate stored in the checkpoint, which takes precedence over
    # --lr — confirm that this is intended.
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    print("=> Resume training from epoch {}".format(start_epoch))

# Every run gets its own timestamped folder holding the arguments, the
# tensorboard events and the per-epoch checkpoints.
timestamp = datetime.now().strftime("%m%d-%H%M")
log_folder = join('checkpoints', f'{args.arch}_{timestamp}')
print(f'=> create log folder: {log_folder}')
os.makedirs(log_folder, exist_ok=True)
with open(join(log_folder, 'args.json'), 'w') as f:
    json.dump(vars(args), f, indent=1)
writer = SummaryWriter(log_dir=log_folder)


# Dataloader

In [None]:
# setup transform
# Both sizes are given in (width, height) order.
image_size = (args.input_width, args.input_height)
depth_size = (args.output_width, args.output_height)


def ToPIL(x):
    """Copy tensor `x` to the CPU and convert it to a PIL image."""
    return transforms.ToPILImage()(x.cpu())


# Resize -> tensor conversion -> normalisation, applied to every sample.
train_steps = [Resize(image_size, depth_size), ToTensor(), Normalize()]
train_transform = transforms.Compose(train_steps)

In [None]:
# Training dataset: frames listed in the train-list file, read from the
# dataset root and passed through the training transform.
root_dir = args.root_dir
filename_txt = args.train_list
trainset = OmniStereoDataset(root_dir, filename_txt, transform=train_transform)
print('{} samples were found.'.format(len(trainset)))

In [None]:
# Shuffled training loader plus a manual iterator for previewing batches.
train_loader = DataLoader(
    trainset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
)
loader_iter = iter(train_loader)


In [None]:
# Fetch one batch for a visual sanity check.  The builtin next() is the
# Python 3 iterator protocol; the iterator's `.next()` method is a Python 2
# leftover that newer PyTorch releases no longer provide.
batch = next(loader_iter)
tensor = batch['cam1'][0]   # first sample of camera 'cam1'
# The image is displayed as 0.5 + 0.5*x.
# NOTE(review): presumably this undoes Normalize() mapping pixels to [-1, 1] — confirm.
plt.imshow(ToPIL(0.5+0.5*tensor))

In [None]:
# Show the 'idepth' entry of the first sample in the previewed batch.
invd = batch['idepth'][0]
invd_array = invd.numpy()
plt.imshow(invd_array)

In [None]:
# Block until every task submitted to the thread pool has completed.
print('=> wait for a while until all tasks in pool are finished')
pool.shutdown(wait=True)
print('=> Done!')

In [None]:
# # cache
# with torch.no_grad():
#     for key in batch.keys():
#         batch[key] = batch[key].to(device)
# #     out = model(batch)
# # #     del out, batch # save memory

# Training

In [None]:
# # Single batch overfitting
# from tqdm.notebook import tqdm
# from collections import OrderedDict

# # collect few batch
# batchs = []
# for it in train_loader:
#     batchs.append(it)
#     if len(batchs) > 10:
#         break
        
# # Start overfitting
# model.train()
# losses = []
# pbar = tqdm(range(1000))
# for it in pbar:
#     batch = random.choice(batchs)
#     # to cuda
#     for key in batch.keys():
#         batch[key] = batch[key].to(device)
#     pred = model(batch)

#     gt_idepth = batch['idepth']
#     # Loss function
#     gt_invd_idx = converter.invdepth_to_index(gt_idepth)
#     loss = nn.L1Loss()(pred, gt_invd_idx)
#     losses.append(loss.item())

#     # update parameters
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()

#     # update progress bar
#     display = OrderedDict(it=f"{it:>2}", loss=f"{losses[-1]:.4f}")
#     pbar.set_postfix(display)

# plt.title('Loss (log)')
# plt.plot(losses)
# plt.yscale('log')
# plt.show()

In [None]:
from tqdm.notebook import tqdm
from collections import OrderedDict
# Main training loop: one pass over train_loader per epoch, tensorboard
# logging along the way and a checkpoint at the end of every epoch.
# The L1 criterion is created once and reused for every batch.
criterion = nn.L1Loss()
for epoch in range(start_epoch, args.epochs):
    model.train()
    losses = []
    pbar = tqdm(train_loader)
    for idx, batch in enumerate(pbar):
        # to cuda
        for key in batch.keys():
            batch[key] = batch[key].to(device)
        pred = model(batch)

        gt_idepth = batch['idepth']
        # Loss function: L1 between the prediction and the ground truth
        # converted from inverse depth to an inverse-depth index.
        gt_invd_idx = converter.invdepth_to_index(gt_idepth)
        loss = criterion(pred, gt_invd_idx)
        losses.append(loss.item())

        # update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update progress bar
        display = OrderedDict(epoch=f"{epoch:>2}",loss=f"{losses[-1]:.4f}")
        pbar.set_postfix(display)

        # tensorboard log
        niter = epoch*len(train_loader)+idx   # global step across epochs
        if idx % args.log_interval == 0:
            writer.add_scalar('train/loss', loss.item(), niter)
        # Images are logged 100x less often than the scalar.  The
        # parentheses matter: `idx % 100*args.log_interval` would be parsed
        # as `(idx % 100) * args.log_interval` and ignore the interval.
        if idx % (100*args.log_interval) == 0:
            # First sample of every camera, shown as 0.5*x + 0.5.
            imgs = []
            for cam in model.cam_list:
                imgs.append(0.5*batch[cam][0]+0.5)
            img_grid = make_grid(imgs, nrow=2, padding=5, pad_value=1)
            writer.add_image('train/fisheye', img_grid, niter)
            # NOTE(review): pred / gt are passed to add_image as they are,
            # which looks like it assumes batch size 1 — confirm before
            # training with a larger batch.
            writer.add_image('train/pred', pred/model.ndisp, niter)
            writer.add_image('train/gt', gt_invd_idx/model.ndisp, niter)

    # End of one epoch
    scheduler.step()
    ave_loss = sum(losses)/len(losses)
    writer.add_scalar('train/loss_ave', ave_loss, epoch)
    print(f"Epoch:{epoch}, Loss average:{ave_loss:.4f}")

    # Everything needed to resume training, plus the model geometry that is
    # checked when the checkpoint is loaded again.
    save_data = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'ave_loss' : ave_loss,
        'ndisp' : model.ndisp,
        'min_depth' : model.min_depth,
        'output_width' : model.w,
        'output_height' : model.h,
    }
    torch.save(save_data, join(log_folder, f'checkpoints_{epoch}.pth'))
    
#     plt.title(f'epoch {epoch}:Loss (log)')
#     plt.plot(losses)
#     plt.yscale('log')
#     plt.show()


# Inference

In [None]:
from torchvision.utils import make_grid

In [None]:
# Validation data: pick one of the known validation sets by name.
val_sets = {
    'omnithings': ('../datasets/omnithings', 'omnithings_val.txt'),
    'omnihouse': ('../datasets/omnihouse', 'omnihouse_val.txt'),
}
root_dir, filename_list = val_sets['omnihouse']

# The validation set reuses the training transform and is read in file order.
valset = OmniStereoDataset(root_dir, filename_list, transform=train_transform)
val_loader = DataLoader(
    valset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.workers,
)
loader_iter = iter(val_loader)
# loader_iter = iter(train_loader)
print(filename_list)
print('val_loader')

In [None]:
# batch = batchs[3]#
# Take the next validation batch.  Use the builtin next(): the iterator's
# `.next()` method is a Python 2 leftover that newer PyTorch releases no
# longer provide.
batch = next(loader_iter)

In [None]:
# Forward the current batch in eval mode, without building a graph.
model.eval()
with torch.no_grad():
    # Move every entry of the batch onto the model's device (in place).
    for key, value in batch.items():
        batch[key] = value.to(device)
    pred = model(batch)
    gt_idepth = batch['idepth']
    # Ground truth as an inverse-depth index, rounding disabled.
    gt_invd_idx = converter.invdepth_to_index(gt_idepth, round_value=False)

In [None]:
# First sample of every camera, shown as 0.5*x + 0.5, tiled into one image.
imgs = [0.5*batch[cam][0]+0.5 for cam in model.cam_list]
img_grid = ToPIL(make_grid(imgs, padding=5, pad_value=1))

# Prediction and ground truth are index maps; dividing by the number of
# disparities scales them before the PIL conversion.
scale = args.ndisp
pred_vis = ToPIL(pred/scale)
gt_vis = ToPIL(gt_invd_idx/scale)

In [None]:
# Three stacked panels without axis ticks: the input fisheye images, the
# network prediction and the ground truth.
cmap='viridis'
fig, ax = plt.subplots(3, 1, figsize=(12,12), subplot_kw=({"xticks":(), "yticks":()}))
ax[0].set_title('fisheye images')
ax[0].imshow(img_grid)
ax[1].set_title('prediction')
ax[1].imshow(pred_vis, cmap=cmap)
# Title spelling fixed (was 'groudtruth').
ax[2].set_title('groundtruth')
ax[2].imshow(gt_vis, cmap=cmap)

# Error metrics

In [None]:
from tqdm.notebook import tqdm
# Evaluate the model on the first `total` validation batches and print the
# error metrics returned by evaluation_metrics.
preds = []
gts = []
# Never ask for more batches than the loader can deliver; otherwise the
# progress bar total is wrong and the loop ends without reaching `total`.
total = min(512, len(val_loader))


model.eval()
with torch.no_grad():
    for idx, batch in tqdm(enumerate(val_loader), total=total):
        for key in batch.keys():
            batch[key] = batch[key].to(device)
        pred = model(batch)
        gt_idepth = batch['idepth']
        gt_invd_idx = converter.invdepth_to_index(gt_idepth, round_value=False)
        # Keep the results on the CPU so they do not accumulate on the device.
        preds.append(pred.cpu())
        gts.append(gt_invd_idx.cpu())

        if len(gts) >= total:
            break

if not preds:
    raise RuntimeError('val_loader did not yield any batch; nothing to evaluate')

# Concatenate after the loop so that this also happens when the loader is
# exhausted before `total` batches were collected.
preds = torch.cat(preds)
gts = torch.cat(gts)

errors, error_names = evaluation_metrics(preds, gts, args.ndisp)
print("Error: ")
print("{:>8}, {:>8}, {:>8}, {:>8}, {:>8}".format(*error_names))
print("{:8.4f}, {:8.4f}, {:8.4f}, {:8.4f}, {:8.4f}".format(*errors))

# ----------Experimental from here -----------------------

# Real images

In [None]:
from scipy.spatial.transform import Rotation as Rot
def convertPoseToOmniMVS(Twcs, filename):
    """Write camera poses to a text file, one `rotvec tvec` row per pose.

    Args:
        Twcs: sequence of pose matrices; the top-left 3x3 block is read as
            the rotation and the first three entries of column 3 as the
            translation.
        filename: output path handed to `np.savetxt`.

    Each row holds the 3 rotation-vector components followed by the
    translation multiplied by 100 (m -> cm), written with 5 decimals.
    """
    def pose_row(T):
        # rot
        rotvec = Rot.from_matrix(T[:3, :3]).as_rotvec()
        # tvec m -> cm
        tvec = T[:3, 3] * 100
        return np.concatenate((rotvec, tvec))

    rot_t_vecs = np.stack([pose_row(T) for T in Twcs])
    np.savetxt(filename, rot_t_vecs, fmt='%.5f')

## Convert to OmniMVS format

In [None]:
# Load camera poses
data_folder = "../real_data/"
# Passed as `fov` to SphericalSweeping further down.
# NOTE(review): presumably the fisheye field of view in degrees — confirm.
fov = 185
pose_file = join(data_folder, "final_camera_poses.yml")
fs_read = cv2.FileStorage(pose_file, cv2.FILE_STORAGE_READ)
# Fail here with a clear message instead of later on a missing matrix.
if not fs_read.isOpened():
    raise FileNotFoundError(f'could not open camera pose file: {pose_file}')
try:
    Twcs = []
    for i in range(4):
        # get world <- cam transformation
        key = f'originimg{i}'
        Twc = fs_read.getNode(key).mat()
        if Twc is None:
            raise KeyError(f'{key} not found in {pose_file}')
        Twcs.append(Twc)
finally:
    # Close the file handle held by the FileStorage object.
    fs_read.release()

# ocamcalib filenames in data_folder
ocam_files = [
    'calib_results_0.txt',
    'calib_results_1.txt',
    'calib_results_2.txt',
    'calib_results_3.txt'
]
# One fisheye image per camera, in the same order as the calibration files.
img_files = [
    'img0.jpg',
    'img1.jpg',
    'img2.jpg',
    'img3.jpg'
]

In [None]:
# convert to OmniMVS format
poses_path = join(data_folder, 'poses.txt')
convertPoseToOmniMVS(Twcs, poses_path)

# convert to OmniMVS filename
# The calibration files are copied to ocam1.txt, ocam2.txt, ... (1-based).
import shutil
for cam_no, calib_name in enumerate(ocam_files, start=1):
    shutil.copy(
        join(data_folder, calib_name),
        join(data_folder, f'ocam{cam_no}.txt'),
    )

## Change sweeping module

In [None]:
# Replace the model's sweeping module with one built from the real-data
# folder, keeping the model's output resolution.
sweep_kwargs = dict(h=model.h, w=model.w, fov=fov)
new_sweep = SphericalSweeping(data_folder, **sweep_kwargs)
model.sweep = new_sweep

## Forward

In [None]:
# Same pipeline as the training transform, but the real images are resized
# to a fixed 500x500 instead of the size taken from the arguments.
real_image_size = (500, 500)
transform = transforms.Compose([
    Resize(real_image_size, depth_size),
    ToTensor(),
    Normalize(),
])

In [None]:
# Load images
# One entry per camera: the key is the model's camera name, the value is the
# image loaded with that camera's valid-area mask from the sweeping module.
batch = {
    model.cam_list[cam_idx]: load_image(
        join(data_folder, img_files[cam_idx]),
        valid=model.sweep.valid_area(cam_idx),
    )
    for cam_idx in range(4)
}

batch = transform(batch)

In [None]:
# Forward the real images in eval mode, without building a graph.
model.eval()
with torch.no_grad():
    for key in batch.keys():
        on_device = batch[key].to(device)
        # 3-dim entries get a new leading axis of size 1 before the forward
        # pass; 4-dim entries are left untouched.
        batch[key] = on_device.unsqueeze(0) if on_device.dim() == 3 else on_device
    pred = model(batch)


In [None]:
# First sample of every camera, shown as 0.5*x + 0.5, tiled into one image.
imgs = [0.5*batch[cam][0]+0.5 for cam in model.cam_list]
img_grid = ToPIL(make_grid(imgs, padding=5, pad_value=1))

# No ground truth is visualised for the real images, only the prediction.
scale = args.ndisp
pred_vis = ToPIL(pred/scale)
# gt_vis = ToPIL(gt_invd_idx/args.ndisp)

In [None]:
img_grid

In [None]:
pred_vis

# Check activations

In [None]:
# model.forward??

In [None]:
from functools import partial

In [None]:
# Stores filled by vis_hook: layer name -> captured tensor.
vis_inputs = dict()
vis_outputs = dict()

def vis_hook(m, i, o, name):
    """Forward hook recording a module's first input and its output.

    Args:
        m: the module the hook is attached to (unused).
        i: the module's positional inputs; only `i[0]` is kept.
        o: the module's output.
        name: key under which both are stored in `vis_inputs` / `vis_outputs`.
    """
    first_input = i[0]
    vis_inputs[name], vis_outputs[name] = first_input, o
    
# add hook
# (sub-module attribute on the model, key used in vis_inputs / vis_outputs)
hooked_layers = (
    ('transference', 'transference'),
    ('fusion', 'fusion'),
    ('cost_regularization', 'cost_reg'),
)
for attr_name, store_key in hooked_layers:
    getattr(model, attr_name).register_forward_hook(partial(vis_hook, name=store_key))

In [None]:
# Forward pass that fills vis_inputs / vis_outputs through the hooks.
model.eval()
with torch.no_grad():
    for key in batch.keys():
        batch[key] = batch[key].to(device)
    pred = model(batch)
    # The most recent batch in this notebook is built from camera images
    # only, so it may carry no 'idepth' entry; compute the error only when
    # ground truth is available instead of failing with a KeyError.
    if 'idepth' in batch:
        gt_idepth = batch['idepth']
        gt_invd_idx = converter.invdepth_to_index(gt_idepth)
        # Absolute difference between prediction and ground-truth index.
        error = torch.abs(pred-gt_invd_idx)
    else:
        gt_idepth = gt_invd_idx = error = None

In [None]:
vis_inputs['transference'].shape

In [None]:
# Input of the transference module: take batch element 0, index `invd_idx`
# on dim 2, and tile the slices along dim 1 into a grid.
# NOTE(review): dim 1 is presumably the feature channel — confirm.
invd_idx = 9
transference_in = vis_inputs['transference']
vis_tensor = transference_in[0, :, invd_idx][:, None]   # add a singleton dim 1 for make_grid
grid_img = make_grid(vis_tensor, padding=5, pad_value=1)
ToPIL(grid_img)

In [None]:
vis_inputs['fusion'].shape

In [None]:
# Input of the fusion module: batch element 0, index `invd_idx` on dim 2,
# tiled along dim 1 with per-grid normalisation.
invd_idx = 0
fusion_in = vis_inputs['fusion']
vis_tensor = fusion_in[0, :, invd_idx][:, None]   # add a singleton dim 1 for make_grid
grid_img = make_grid(vis_tensor, padding=5, pad_value=1, normalize=True)
ToPIL(grid_img)

In [None]:
# Output of the fusion module: batch element 0, index `invd_idx` on dim 2,
# tiled along dim 1 with per-grid normalisation.
invd_idx = 4
fusion_out = vis_outputs['fusion']
vis_tensor = fusion_out[0, :, invd_idx][:, None]   # add a singleton dim 1 for make_grid
grid_img = make_grid(vis_tensor, padding=5, pad_value=1, normalize=True)
ToPIL(grid_img)

In [None]:
vis_outputs['cost_reg'].shape

In [None]:
# Output of the cost regularisation: batch element 0, index 0 on dim 1 and
# the first 64 entries of dim 3; the two leading remaining dims are swapped
# so the grid is tiled along what was dim 3.
cost_volume = vis_outputs['cost_reg']
vis_tensor = cost_volume[0, 0, :, :64, :].transpose(0, 1)[:, None]
grid_img = make_grid(vis_tensor, padding=5, pad_value=1, normalize=True)
ToPIL(grid_img)