In [None]:
# from datasets.dtu import DTUDataset
from datasets import find_dataset_def
import torch
import torchvision.transforms as T
import matplotlib.pyplot as plt
import cv2
import numpy as np
import time
import os

# Turn on cuDNN's benchmark mode for the whole session.
# NOTE(review): benchmark mode is generally only a win when input shapes stay
# constant between calls — confirm the test images share one resolution.
torch.backends.cudnn.benchmark = True # this increases inference speed a little

In [None]:
from eval_rcmvsnet_dtu import args, testlist, Interval_Scale
from train_rcmvsnet import test_sample_depth

# Load pretrained model

In [None]:
# from models.mvsnet import CascadeMVSNet
from models.casmvsnet import *
# from utils import load_ckpt
# from inplace_abn import ABN
from torch.utils.data import DataLoader
from utils import *
from torch.utils.tensorboard import SummaryWriter


# Loss function handed to test_sample_depth() in the evaluation loop.
test_model_loss = cas_mvsnet_loss
# Override the imported evaluation arguments for this notebook run.
args.num_view = 4
args.loadckpt ='./rc-mvsnet1/model_000010_cas.ckpt'
# numdepth = 192
# interval_scale = 1.06
# max_h = 1200
# max_w = 1600
# TensorBoard writer receiving the scalars/images produced by the test loop.
logger = SummaryWriter('./depth_test_results')

# dataset, dataloader
MVSDataset = find_dataset_def(args.dataset)
test_dataset = MVSDataset(args.testpath, testlist, "test", args.num_view, args.numdepth, Interval_Scale,
                            max_h=args.max_h, max_w=args.max_w, fix_res=args.fix_res)
# shuffle=False keeps the batch order (and therefore the logged step indices) stable.
TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=1, drop_last=False)


# The comma-separated CLI strings are parsed into per-stage lists; empty
# fragments (e.g. from a trailing comma) are skipped by the `if` filters.
model = CascadeMVSNet_eval(refine=False, ndepths=[int(nd) for nd in args.ndepths.split(",") if nd],
            depth_interals_ratio=[float(d_i) for d_i in args.depth_inter_r.split(",") if d_i],
            share_cr=args.share_cr,
            cr_base_chs=[int(ch) for ch in args.cr_base_chs.split(",") if ch],
            grad_method=args.grad_method)

# Deserialize on the CPU first so loading does not depend on the device the
# checkpoint was saved from; strict=True makes any key mismatch fail loudly.
state_dict = torch.load(args.loadckpt, map_location=torch.device("cpu"))
model.load_state_dict(state_dict['model'], strict=True)
# model = nn.DataParallel(model)
# The inference cells of this notebook move every batch to the GPU (tocuda /
# .cuda()), so the weights have to live on the GPU as well; otherwise the
# forward pass fails with a CPU/CUDA device mismatch on a fresh kernel.
model.cuda()
# Evaluation mode: layers that behave differently while training (e.g. batch
# norm or dropout, if the network contains any) switch to inference behaviour.
model.eval()

In [None]:
# testing
# Runs test_sample_depth() over every batch of TestImgLoader, logs a summary
# every `summary_interval` batches, and finally logs/prints the running mean of
# all per-batch scalars.
summary_interval = 48  # batches between TensorBoard summaries / progress prints
global_step = 0  # keeps the final save_scalars() call valid if the loader yields no batch
with torch.no_grad():
    avg_test_scalars = DictAverageMeter()
    for batch_idx, sample in enumerate(TestImgLoader):
        start_time = time.time()
        global_step = batch_idx
        do_summary = global_step % summary_interval == 0
        loss, scalar_outputs, image_outputs = test_sample_depth(model, test_model_loss, sample, args)
        if do_summary:
            save_scalars(logger, 'test', scalar_outputs, global_step)
            save_images(logger, 'test', image_outputs, global_step)
            # The elapsed time is formatted with `.3f` like every other field
            # (the previous `{:3f}` set a field *width* of 3, not a precision).
            print("Iter {}/{}, test loss = {:.3f}, depth loss = {:.3f}, thres2mm_accu = {:.3f},thres4mm_accu = {:.3f},thres8mm_accu = {:.3f},thres2mm_error = {:.3f},thres4mm_error = {:.3f},thres8mm_error = {:.3f},time = {:.3f}".format(
                                                                batch_idx,
                                                                len(TestImgLoader), loss,
                                                                scalar_outputs["depth_loss"],
                                                                scalar_outputs["thres2mm_accu"],
                                                                scalar_outputs["thres4mm_accu"],
                                                                scalar_outputs["thres8mm_accu"],
                                                                scalar_outputs["thres2mm_error"],
                                                                scalar_outputs["thres4mm_error"],
                                                                scalar_outputs["thres8mm_error"],
                                                                time.time() - start_time))
        # Every batch contributes to the running average, not only the logged ones.
        avg_test_scalars.update(scalar_outputs)
        # Drop the per-batch outputs before the next iteration allocates new ones.
        del scalar_outputs, image_outputs

    save_scalars(logger, 'fulltest', avg_test_scalars.mean(), global_step)
    print("avg_test_scalars:", avg_test_scalars.mean())

# Push buffered events to disk so TensorBoard shows the results right away;
# the kernel (and thus the writer) typically stays alive after this cell.
logger.flush()

In [None]:
# Number of cascade stages, counted from args.ndepths with the same filter used
# when the model's `ndepths` list is built. It selects which "stage{N}"
# projection matrices are read below.
# NOTE(review): assumes the stage keys are 1-based, so the last stage is
# "stage{num_stage}" — confirm against the dataset's proj_matrices dict.
num_stage = len([nd for nd in args.ndepths.split(",") if nd])

# Forward every test batch through the network and report the per-batch latency.
# After the loop, the variables below hold the values of the LAST batch only.
with torch.no_grad():
    for batch_idx, sample in enumerate(TestImgLoader):
        sample_cuda = tocuda(sample)
        # CUDA work is queued asynchronously: synchronize before reading the
        # clock on both sides so the measurement covers the forward pass itself.
        torch.cuda.synchronize()
        start_time = time.time()
        outputs = model(sample_cuda["imgs"], sample_cuda["proj_matrices"], sample_cuda["depth_values"])
        torch.cuda.synchronize()
        end_time = time.time()
        outputs = tensor2numpy(outputs)
        # Release the GPU copy of the inputs; the CPU `sample` is still used below.
        del sample_cuda
        filenames = sample["filename"]
        cams = sample["proj_matrices"]["stage{}".format(num_stage)].numpy()
        # imgs = sample["imgs"].numpy()
        imgs = sample["imgs"]
        depth_values=sample["depth_values"]
        # First and last depth hypothesis of the first item in the batch.
        depth_start,depth_end =depth_values[0][0],depth_values[0][-1]
        
        # print("sample[imgs]: {}".format(sample["imgs"].shape))
        # imgs = inv_normalize(sample["imgs"].squeeze()).unsqueeze(dim=0).numpy()
        # imgs = sample["imgs_raw"].numpy()
        print('Iter {}/{}, Time:{} Res:{}'.format(batch_idx, len(TestImgLoader), end_time - start_time, imgs[0].shape))

# Visualize an example depth

In [None]:
from utils import *

# Draw the first entry of `imgs` and blend the masked 'level_0' depth map over it.
# NOTE(review): `depths` and `masks` are not assigned in any visible cell, and
# `unpreprocess` / `visualize_depth` are presumably supplied by the wildcard
# import above — confirm before relying on a fresh-kernel run.
reference_rgb = unpreprocess(imgs[0]).permute(1,2,0)
plt.imshow(reference_rgb)
masked_depth_vis = visualize_depth(depths['level_0']*masks['level_0']).permute(1,2,0)
plt.imshow(masked_depth_vis, alpha=0.5)

# Do inference on this sample

In [None]:
# Time one forward pass on the most recent `sample` left behind by the
# DataLoader loop. The model is called with the same three inputs as in that
# loop (images, per-stage projection matrices, depth hypotheses); the former
# call referenced `proj_mats`, `init_depth_min` and `depth_interval`, none of
# which are assigned in any visible cell of this notebook.
# NOTE(review): cells below index `results` with keys such as 'depth_0' —
# confirm those keys against what this model actually returns.
sample_cuda = tocuda(sample)
torch.cuda.synchronize()  # finish the host-to-device copies before starting the clock
t = time.time()
with torch.no_grad():
    results = model(sample_cuda["imgs"], sample_cuda["proj_matrices"], sample_cuda["depth_values"])
    # Wait for the queued GPU kernels so the elapsed time covers the real work.
    torch.cuda.synchronize()
print('inference time', time.time()-t)

In [None]:
# Overlay the predicted 'depth_0' map (first batch item) on the first entry of `imgs`.
# NOTE(review): `unpreprocess` / `visualize_depth` are not defined in any
# visible cell; presumably they come from a wildcard import — confirm.
background_rgb = unpreprocess(imgs[0]).permute(1,2,0)
plt.imshow(background_rgb)
predicted_depth_vis = visualize_depth(results['depth_0'][0]).permute(1,2,0)
plt.imshow(predicted_depth_vis, alpha=0.5)

In [None]:
# Binary map of the pixels whose 'confidence_2' value (first batch item) exceeds 0.999.
confidence_map = results['confidence_2'][0].cpu().numpy()
plt.imshow(confidence_map>0.999)

# Reference: show pixels whose absolute depth error is less than 2mm

In [None]:
# Boolean map (first batch item) of pixels whose absolute depth difference is below 2.
# NOTE(review): `depths` and `masks` are not assigned in any visible cell — confirm
# where they are expected to come from.
abs_depth_diff = (depths['level_0']-results['depth_0'].cpu()).abs()
err2 = abs_depth_diff[0]<2
plt.imshow(err2);
# Share of masked pixels that fall under the threshold.
valid_mask = masks['level_0']
hits_inside_mask = (err2.float()*valid_mask).sum()
print('acc_2mm :', (hits_inside_mask/valid_mask.sum()).item())

# FLOPs example

In [None]:
import torch
from ptflops import get_model_complexity_info

# Template for measuring model complexity with ptflops.
# NOTE(review): `MyModel` and the checkpoint path below look like placeholders
# (neither is defined in a visible cell) — substitute the real network and
# weights before running. Running it also rebinds the notebook-level `model`.
model = MyModel()
model.load_state_dict(torch.load('path/to/checkpoint.pth'))

# Report the multiply-accumulate count and parameter count for a (3, 224, 224)
# input, printing per-layer statistics along the way.
with torch.cuda.device(0):
    macs, params = get_model_complexity_info(
        model,
        (3, 224, 224),
        as_strings=True,
        print_per_layer_stat=True,
        verbose=True,
    )
    print(f"{'Computational complexity: ':<30}  {macs:<8}")
    print(f"{'Number of parameters: ':<30}  {params:<8}")
