In [1]:

import torchvision.transforms as transforms
from dataset.sceneflowdataloader import SceneFlowLoader
import torch

from glob import glob
from torch.utils.data import Dataset
import cv2
import numpy as np
import os
from tqdm import tqdm

import torch.nn.parallel
import torch.utils.data
import torchvision.transforms as transforms
from torchvision.utils import make_grid

import superpixelnet.flow_transforms as flow_transforms
from superpixelnet.models.Spixel_single_layer import SpixelNet1l_bn
from superpixelnet.loss import compute_semantic_pos_loss
import datetime

from superpixelnet.train_util import *

# psmnet
from  models import *
from models.submodule import disparityregression

import matplotlib.pyplot as plt

from PIL import Image
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure


In [2]:
def imgtensor2np(img):
    """Convert a 3-D channels-first tensor (C, H, W) to a channels-last numpy array (H, W, C).

    The tensor is detached from autograd and copied to the CPU first.
    """
    channels_last = img.permute(1, 2, 0)
    return channels_last.detach().cpu().numpy()
def to_device(args, device):
    """Move every tensor in ``args`` to ``device``.

    Items that are Python lists are moved element-wise and stay lists.
    Returns a new list in the same order as ``args``.
    """
    def _move(item):
        if isinstance(item, list):
            return [tensor.to(device) for tensor in item]
        return item.to(device)

    return [_move(item) for item in args]

class ARG:
    """Hard-coded run configuration, exposed as plain instance attributes."""

    def __init__(self):
        # --- data / architecture ---
        self.dataset = 'SceneFlowPSM'
        self.arch = 'SpixelNet1l_bn'
        # alternative data roots used before: './data_preprocessing/Data', './NYU'
        self.data = "./dataset/Monkaa"
        self.savepath = './checkpoints'

        # --- image resolution (previously 256 x 512) ---
        self.train_img_height = 128
        self.train_img_width = 256
        self.input_img_height = self.train_img_height
        self.input_img_width = self.train_img_width

        # --- schedule ---
        self.workers = 4
        self.epochs = 100000
        self.start_epoch = 0
        self.epoch_size = 6000
        self.batch_size = 1

        # --- optimiser ---
        self.solver = 'adam'
        self.lr = 0.0005  # 0.000005 was also tried
        self.momentum = 0.9
        self.beta = 0.999
        self.weight_decay = 4e-4
        self.bias_decay = 0
        self.milestones = [200000]
        self.additional_step = 100000

        # --- superpixel network ---
        self.pos_weight = 0.003
        self.downsize = 16

        # --- runtime / logging / checkpoints ---
        self.gpu = '0'
        self.print_freq = 10
        self.record_freq = 5
        self.label_factor = 5
        self.pretrained = "./checkpoints/SceneFlowPSM/SpixelNet1l_bn_adam_100000epochs_epochSize6000_b8_lr0.0005_posW0.003_/sfcn/model_best.tar"
        self.no_date = True

        # --- PSMNet ---
        self.maxdisp = 192
        self.psmmodel = 'basic'
        self.seed = 1
        self.pretrainedpsmnet = "./checkpoints/SceneFlowPSM/SpixelNet1l_bn_adam_100000epochs_epochSize6000_b8_lr0.0005_posW0.003_/psm/model_best.tar"

args = ARG()

# !----- NOTE the current code does not support cpu training -----!
# CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is initialised
# in this process.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if not torch.cuda.is_available():
    # Raise instead of calling exit(): inside a Jupyter kernel exit() stops the
    # whole kernel rather than just failing this cell.
    raise RuntimeError('Current code does not support CPU training! Sorry about that.')
device = torch.device("cuda")

# Number of one-hot label channels (C=50, as SSN does).
N_CLASSES = 50
def foldercheck(save_path):
    """Create ``save_path`` (including missing parents) unless something already exists there.

    If the path already exists, even as a regular file, nothing is done.
    """
    if os.path.exists(save_path):
        return
    os.makedirs(save_path)


In [3]:
# ==========  Data loading code ==============
def _make_rgb_input_transform():
    """Build the image pipeline: array -> tensor, divide by 255, subtract the per-channel mean."""
    return transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
        transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1]),
    ])


# Train and validation inputs use the same pipeline (built as two separate objects).
input_transform = _make_rgb_input_transform()
val_input_transform = _make_rgb_input_transform()

# Targets are only converted to tensors; no normalisation is applied.
target_transform = transforms.Compose([flow_transforms.ArrayToTensor()])

import numbers

# Resize to the training resolution; presumably applied jointly to images and
# targets by the loader (TODO confirm in SceneFlowLoader).
co_transform = flow_transforms.Compose([
    flow_transforms.Resize((args.train_img_height, args.train_img_width)),
])

dataset = SceneFlowLoader(
    args.data,
    mode='eating_camera2_x2',
    transform=input_transform,
    target_transform=target_transform,
    co_transform=co_transform,
)

Total classes 47.
Number of eating_camera2_x2 samples : 151


In [4]:
# Unshuffled, so samples (and the frames written from them) keep dataset order.
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.workers,
)


In [5]:
# ============== Load S-FCN model ====================
if args.pretrained:
    network_data = torch.load(args.pretrained)
    # NOTE(review): args.arch is overwritten from the checkpoint, but the model
    # below is always SpixelNet1l_bn regardless of its value.
    args.arch = network_data['arch']
    print("=> using pre-trained model '{}'".format(args.arch))
    print(f"was trained until {network_data['epoch']} epochs")
else:
    network_data = None
    print("=> creating model '{}'".format(args.arch))

model = SpixelNet1l_bn( data = network_data).cuda()
model = torch.nn.DataParallel(model).cuda()


# ============== Load PSMNet model ====================
if args.psmmodel == 'stackhourglass':
    psmnet = stackhourglass(args.maxdisp, slicmode=False)
elif args.psmmodel == 'basic':
    psmnet = basic(args.maxdisp)
else:
    # Previously this only printed 'no model' and then crashed with a NameError
    # on `psmnet` below; fail immediately with a clear message instead.
    raise ValueError(f"Unknown psmmodel {args.psmmodel!r}; expected 'stackhourglass' or 'basic'")
psmnet = psmnet.to(device)

# Truthiness check (consistent with args.pretrained above) also skips an empty path.
if args.pretrainedpsmnet:
    print('Load pretrained model')
    pretrain_dict = torch.load(args.pretrainedpsmnet)
    # Weights are loaded before wrapping in DataParallel, so keys carry no 'module.' prefix.
    psmnet.load_state_dict(pretrain_dict['state_dict'])

psmnet = torch.nn.DataParallel(psmnet)
print('Number of PSM model parameters: {}'.format(sum([p.data.nelement() for p in psmnet.parameters()])))

# XY_feat: the coordinate feature for position loss term
spixelID, XY_feat_stack = init_spixel_grid(args)


=> using pre-trained model 'SpixelNet1l_bn'
was trained until 53 epochs
Load pretrained model
Number of PSM model parameters: 3287360


In [None]:
def computeOutput(cost, Ql, S):
    """Turn the cost volume computed on superpixel-pooled images into a dense disparity map.

    Steps:
    1. Un-pool ``cost`` to pixel resolution with the association map ``Ql``
       (``upfeat`` with ``S`` x ``S`` cells).
    2. Resize trilinearly to ``args.maxdisp`` bins along the disparity axis.
    3. Softmax-normalise over that axis.
    4. Reduce with ``disparityregression``.

    Args:
        cost: cost volume from the PSM network. Presumably (B, D', h, w); TODO confirm.
        Ql: association map of the left image; its dims 2 and 3 give the output H and W.
        S: cell size passed to ``upfeat`` for both spatial axes.

    Returns:
        The result of ``disparityregression(args.maxdisp)`` applied to the probability volume.
    """
    # Import explicitly instead of relying on a star import to provide `F`.
    import torch.nn.functional as F

    upsampled_cost = upfeat(cost, Ql, S, S)
    assert upsampled_cost.size()[2] == Ql.size()[2]
    assert upsampled_cost.size()[3] == Ql.size()[3]
    # F.upsample is deprecated; F.interpolate is its drop-in replacement with the
    # same defaults. unsqueeze(1) makes the 5-D input that 'trilinear' requires.
    upsampled_cost = F.interpolate(
        upsampled_cost.unsqueeze(1),
        [args.maxdisp, Ql.size()[2], Ql.size()[3]],
        mode='trilinear',
    ).squeeze(1)
    # Softmax over the disparity axis. The implicit-dim form is deprecated; for a
    # 4-D tensor it resolved to dim=1, so results are unchanged.
    output = F.softmax(upsampled_cost, dim=1)
    return disparityregression(args.maxdisp)(output)
import torch.nn.functional as F


def _render_comparison_frame(leftim, pred_disp, disparity):
    """Plot image, predicted and ground-truth disparity side by side and return a BGR frame.

    The figure is rasterised off-screen, cropped to drop its white margins, and
    closed afterwards so figures do not accumulate across loop iterations.
    """
    fig, axes = plt.subplots(1, 3, figsize=(10, 5), dpi=300)
    # tight_layout runs before the axes are populated (as before); the crop offsets
    # below were tuned for exactly this layout.
    fig.tight_layout()
    canvas = FigureCanvasAgg(fig)
    panels = ((leftim, 'Image'),
              (pred_disp, 'Predicted Disparities'),
              (disparity, 'Groundtruth Disparities'))
    for ax, (img, title) in zip(axes, panels):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis("off"); ax.axis("tight"); ax.axis("image")
    canvas.draw()
    buf, (width, height) = canvas.print_to_buffer()
    plt.close(fig)  # prevent leaking one 3000x1500 figure per frame
    # np.fromstring is deprecated for binary data; frombuffer is the replacement.
    rgba = np.frombuffer(buf, np.uint8).reshape((height, width, 4))
    # Hand-tuned crop for figsize=(10, 5) @ 300 dpi. Channels are reversed
    # (RGBA -> BGR) because cv2.imwrite expects BGR order.
    return np.ascontiguousarray(rgba[450:950, 100:-50, 2::-1])


foldercheck(f'{args.savepath}/disparity_outputs/')
psm_losses = []  # per-frame smooth-L1 loss on pixels with valid ground truth
with torch.no_grad():
    psmnet.eval(); model.eval()
    for i, sample in tqdm(enumerate(data_loader), total=len(data_loader)):
        imL, imR, label, labelR, disp_true = to_device(sample, device)

        # ========== predict association map ============
        Ql = model(imL)
        Qr = model(imR)

        # ========== compute disparity map ============
        # Only pixels whose ground truth lies inside the disparity range are used.
        mask = disp_true < args.maxdisp
        S = 4  # cell size for poolfeat / upfeat
        pooled_imL = poolfeat(imL, Ql.clone(), S, S)
        pooled_imR = poolfeat(imR, Qr.clone(), S, S)

        if args.psmmodel == 'stackhourglass':
            cost1, cost2, cost3 = psmnet(pooled_imL, pooled_imR)
            # NOTE(review): uses the last-stage cost and assumes it has the same
            # layout as the basic model's cost. TODO confirm.
            output = computeOutput(cost3, Ql, S)
        elif args.psmmodel == 'basic':
            cost = psmnet(pooled_imL, pooled_imR)
            output = computeOutput(cost, Ql, S)
        # size_average is deprecated; reduction='mean' is the equivalent.
        psmloss = F.smooth_l1_loss(output[mask], disp_true[mask], reduction='mean')
        psm_losses.append(psmloss.item())

        # ========== Visualization ============
        # Add back the per-channel mean subtracted by input_transform for display.
        mean_values = torch.tensor([0.411, 0.432, 0.45], dtype=imL.dtype).view(3, 1, 1)
        input_l_save = make_grid((imL.detach().cpu() + mean_values).clamp(0, 1), nrow=args.batch_size)

        valid = imgtensor2np(mask[0]).squeeze().astype(np.uint8)
        pred_disp = imgtensor2np(output[0])[..., 0].astype(np.uint8)
        disparity = imgtensor2np(disp_true[0])[..., 0].astype(np.uint8)
        # Scale by 255, not 256: a clamped 1.0 * 256 overflows uint8 and wraps to 0.
        leftim = (imgtensor2np(input_l_save) * 255).astype(np.uint8)

        pred_disp *= valid
        disparity *= valid

        frame = _render_comparison_frame(leftim, pred_disp, disparity)
        cv2.imwrite(f'{args.savepath}/disparity_outputs/{str(i).zfill(5)}.png', frame)

if psm_losses:
    print(f'mean smooth-L1 disparity loss over {len(psm_losses)} frames: {np.nanmean(psm_losses):.4f}')




In [7]:
files = sorted(glob(f'{args.savepath}/disparity_outputs/*png'))
print(len(files))
# Without frames the writer could never be created (previously a NameError on release).
if not files:
    raise FileNotFoundError(f'No frames found in {args.savepath}/disparity_outputs/')

foldercheck(f'{args.savepath}/disparity_video')
first_frame = cv2.imread(files[0])
if first_frame is None:
    raise IOError(f'Could not read frame {files[0]}')
h, w = first_frame.shape[:2]
writer = cv2.VideoWriter(f'{args.savepath}/disparity_video/output1.mp4', cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), 10, (w, h))
try:
    for file in tqdm(files):
        img = cv2.imread(file)
        if img is None:
            raise IOError(f'Could not read frame {file}')
        # VideoWriter silently drops frames whose size differs from the one it was opened with.
        if img.shape[:2] != (h, w):
            raise ValueError(f'Frame {file} has size {img.shape[:2]}, expected {(h, w)}')
        writer.write(img)
finally:
    writer.release()
print("Done stitching")


0it [00:00, ?it/s]

151


151it [00:37,  4.06it/s]

Done stitching



