In [1]:
import numpy as np
import os
import pickle
import gc
import re
import cv2
%matplotlib inline
import matplotlib.pyplot as plt
import tensorflow as tf

from PIL import Image
from mpl_toolkits.axes_grid1 import make_axes_locatable
from IPython.display import clear_output

from tensorflow.keras import layers, Model
from tensorflow.keras.utils import Sequence
from tensorflow.keras.layers import Layer, Conv2D, Conv3D, Conv3DTranspose, BatchNormalization, ReLU
from tensorflow.python.keras.layers.convolutional import Conv3DTranspose
from tensorflow.python.ops.init_ops_v2 import he_normal
from keras import regularizers
from keras.callbacks import EarlyStopping, ModelCheckpoint

# from tensorflow.keras import mixed_precision
print("TensorFlow version: ", tf.__version__)

# Raise the TensorFlow Python logger threshold so only ERROR records are shown.
tf.get_logger().setLevel('ERROR')

# Switch every visible GPU to on-demand ("memory growth") allocation.
# `gpus` stays a notebook-level name so later cells can inspect it.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

TensorFlow version:  2.10.0


In [2]:
def _clamp_disp(disp, min_disp, max_disp):
    """Clip max disparity, ortherwise it'll be hard for network to learn really big disparity/close object"""
    return np.clip(disp, min_disp, max_disp)

def _mean_std(img):
    img = np.array(img, dtype=np.float32) / 255.0
    img[:, :, 0] -= 0.485
    img[:, :, 0] /= 0.229
    img[:, :, 1] -= 0.456
    img[:, :, 1] /= 0.224
    img[:, :, 2] -= 0.406
    img[:, :, 2] /= 0.225
    
    return img

def StereoDataloader(img_left, img_right, disp_left, img_h, img_w, df_h, df_w, batch_num, ComPerBatch, data_ord, max_disp):
    """Load one batch of cropped stereo pairs and their left disparity maps.

    Args:
        img_left, img_right: indexable collections of left/right image files
            readable by ``PIL.Image.open``.
        disp_left: indexable collection of comma-separated disparity text files.
        img_h, img_w: crop height and width in pixels.
        df_h, df_w: vertical/horizontal margin available for the crop offset.
            A random offset in ``[0, df)`` is drawn only when the margin is
            larger than 32; otherwise the crop starts at 0.
        batch_num: index of the batch to load.
        ComPerBatch: number of samples per batch.
        data_ord: sample ordering; positions
            ``batch_num*ComPerBatch .. (batch_num+1)*ComPerBatch - 1`` are used.
        max_disp: upper bound applied to the disparity values (lower bound is 0).

    Returns:
        Tuple ``(images, disparities)`` of float32 arrays. Each image entry is
        the normalised left and right crops concatenated along the channel
        axis; each disparity entry is the clamped left disparity crop.
    """
    # One crop offset is drawn per call and shared by every sample in the batch.
    top = np.random.randint(0, df_h) if df_h > 32 else 0
    left_edge = np.random.randint(0, df_w) if df_w > 32 else 0
    rows = slice(top, top + img_h)
    cols = slice(left_edge, left_edge + img_w)

    stereo_pairs = []
    disparities = []
    first = batch_num * ComPerBatch
    for position in range(first, (batch_num + 1) * ComPerBatch):
        sample = data_ord[position]

        left_crop = np.array(Image.open(img_left[sample]))[rows, cols, :]
        right_crop = np.array(Image.open(img_right[sample]))[rows, cols, :]
        pair = np.concatenate((_mean_std(left_crop), _mean_std(right_crop)), axis=2)
        stereo_pairs.append(pair)

        disparity = np.loadtxt(disp_left[sample], delimiter=",", dtype=np.float32)[rows, cols]
        disparities.append(_clamp_disp(disparity, 0, max_disp))

    return np.array(stereo_pairs, dtype=np.float32), np.array(disparities, dtype=np.float32)


In [3]:
def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation):
    """Build a bias-free 2-D convolution followed by batch normalisation.

    ``in_planes`` is accepted for signature parity with the call sites but is
    not used: the convolution is configured from ``out_planes`` only.
    ``pad == 0`` selects 'valid' padding; any other value selects 'same'.
    """
    padding_mode = 'same' if pad != 0 else 'valid'
    conv = tf.keras.layers.Conv2D(
        out_planes,
        kernel_size=kernel_size,
        strides=stride,
        padding=padding_mode,
        dilation_rate=dilation,
        use_bias=False,
    )
    norm = tf.keras.layers.BatchNormalization()
    return tf.keras.Sequential([conv, norm])

def convbn_3d(in_planes, out_planes, kernel_size, stride, pad, dilation=1):
    """Build a bias-free 3-D convolution followed by batch normalisation.

    Args:
        in_planes: accepted for signature parity with ``convbn``; not used.
        out_planes: number of output filters.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        pad: 0 selects 'valid' padding, any other value selects 'same'.
        dilation: dilation rate of the convolution. Defaults to 1 (no
            dilation). The body previously read an undefined name
            ``dilation`` and raised NameError on every call; existing
            five-argument calls now work unchanged.

    Returns:
        A ``tf.keras.Sequential`` of ``Conv3D`` + ``BatchNormalization``.
    """
    return tf.keras.Sequential([
        tf.keras.layers.Conv3D(out_planes, kernel_size=kernel_size, strides=stride, padding='valid' if pad == 0 else 'same', dilation_rate=dilation, use_bias=False),
        tf.keras.layers.BatchNormalization()
    ])

class BasicBlock(tf.keras.Model):
    """Residual block: conv-BN-ReLU, conv-BN, then add the shortcut.

    The shortcut is the block input, optionally passed through the
    ``downsample`` layer supplied by the caller. No activation is applied
    after the addition.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride, downsample, pad, dilation):
        super().__init__()
        # Sub-layers are created in the same order as they are applied.
        first_stage = convbn(inplanes, planes, 3, stride, pad, dilation)
        self.conv1 = tf.keras.Sequential([first_stage, tf.keras.layers.ReLU()])
        self.conv2 = convbn(planes, planes, 3, 1, pad, dilation)
        self.downsample = downsample
        self.stride = stride

    @tf.function
    def call(self, x):
        """Return ``conv2(conv1(x))`` plus the (optionally projected) input."""
        residual = self.conv2(self.conv1(x))
        shortcut = x if self.downsample is None else self.downsample(x)
        return residual + shortcut

class SoftArgMin(Layer):
    """Regress a disparity per pixel as the softmax-weighted mean of indices.

    Note that the softmax is applied to the input as given (no sign flip).
    """

    def __init__(self, maxdisp):
        super().__init__()
        # Candidate disparities 0 .. maxdisp-1, shaped to broadcast over [N, D, H, W].
        candidates = tf.range(maxdisp, dtype=tf.float32)
        self.disp_indices = tf.reshape(candidates, [1, maxdisp, 1, 1])

    def call(self, x):
        """Map scores of shape [N, D, H, W] to expected disparities [N, H, W]."""
        probabilities = tf.nn.softmax(x, axis=1)  # normalise over the disparity axis
        weighted = tf.math.multiply(probabilities, self.disp_indices)
        return tf.math.reduce_sum(weighted, axis=1)
    
class feature_extraction(tf.keras.Model):
    """2-D feature extractor applied to each image of the stereo pair.

    Structure: a three-convolution stem, four residual stages built from
    ``BasicBlock``, four average-pooling branches computed on the last
    stage's output, and a final fusion head (``lastconv``) that reduces the
    concatenated features to 32 channels.
    """
    def __init__(self):
        super(feature_extraction, self).__init__()
        # Running channel count; read and updated by _make_layer to decide
        # whether a stage needs a projection shortcut.
        self.inplanes = 32
        # Stem: the first convolution uses stride 2, the other two stride 1.
        self.firstconv = tf.keras.Sequential([
            convbn(3, 32, 3, 2, 1, 1),
            tf.keras.layers.ReLU(),
            convbn(32, 32, 3, 1, 1, 1),
            tf.keras.layers.ReLU(),
            convbn(32, 32, 3, 1, 1, 1),
            tf.keras.layers.ReLU()
        ])
        # Residual stages; arguments are (block, planes, blocks, stride, pad, dilation).
        self.layer1 = self._make_layer(BasicBlock, 32, 3, 1, 1, 1)
        self.layer2 = self._make_layer(BasicBlock, 64, 16, 2, 1, 1)   # stride 2
        self.layer3 = self._make_layer(BasicBlock, 128, 3, 1, 1, 1)
        self.layer4 = self._make_layer(BasicBlock, 128, 3, 1, 1, 2)   # dilation 2
        # Pooling branches: average pooling with windows 64, 32, 16 and 8,
        # each followed by a 1x1 conv-BN-ReLU producing 32 channels.
        # NOTE(review): no padding is requested on the pooling layers, so the
        # feature map presumably has to be at least 64 wide/high -- confirm.
        self.branch1 = tf.keras.Sequential([
            tf.keras.layers.AveragePooling2D(pool_size=(64, 64), strides=(64, 64)),
            convbn(128, 32, 1, 1, 0, 1),
            tf.keras.layers.ReLU()
        ])
        self.branch2 = tf.keras.Sequential([
            tf.keras.layers.AveragePooling2D(pool_size=(32, 32), strides=(32, 32)),
            convbn(128, 32, 1, 1, 0, 1),
            tf.keras.layers.ReLU()
        ])
        self.branch3 = tf.keras.Sequential([
            tf.keras.layers.AveragePooling2D(pool_size=(16, 16), strides=(16, 16)),
            convbn(128, 32, 1, 1, 0, 1),
            tf.keras.layers.ReLU()
        ])
        self.branch4 = tf.keras.Sequential([
            tf.keras.layers.AveragePooling2D(pool_size=(8, 8), strides=(8, 8)),
            convbn(128, 32, 1, 1, 0, 1),
            tf.keras.layers.ReLU()
        ])
        # Fusion head. 320 = 64 (layer2) + 128 (layer4) + 4 * 32 (branches);
        # convbn ignores its in_planes argument, so the number is informational.
        self.lastconv = tf.keras.Sequential([
            convbn(320, 128, 3, 1, 1, 1),
            tf.keras.layers.ReLU(),
            tf.keras.layers.Conv2D(32, kernel_size=1, padding='valid', strides=1, use_bias=False)
        ])
    
    def _make_layer(self, block, planes, blocks, stride, pad, dilation):
        """Stack ``blocks`` instances of ``block`` into one Sequential stage.

        Only the first block receives ``stride``; the remaining ones use
        stride 1 and no projection. A 1x1 conv + BN projection shortcut is
        created for the first block when the stride is not 1 or the channel
        count changes. Updates ``self.inplanes`` to the stage's output width.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = tf.keras.Sequential([
                tf.keras.layers.Conv2D(planes * block.expansion, kernel_size=1, strides=stride, use_bias=False),
                tf.keras.layers.BatchNormalization()
            ])
        # NOTE: this local list shadows the `layers` module imported at the
        # top of the notebook, but only inside this method.
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, pad, dilation))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, 1, None, pad, dilation))
        
        return tf.keras.Sequential(layers)
    
    @tf.function
    def call(self, x):
        """Compute the fused feature map for one image batch.

        Assumes ``x`` is channels-last, [N, H, W, 3] -- TODO confirm. The
        resize targets are read from ``output_skip.shape``, so the spatial
        dimensions presumably need to be statically known.
        """
        output = self.firstconv(x)
        output = self.layer1(output)
        output_raw = self.layer2(output)
        output = self.layer3(output_raw)
        output_skip = self.layer4(output)
        
        # Each pooled branch is resized back to output_skip's spatial size
        # so that all feature maps can be concatenated channel-wise.
        output_branch1 = self.branch1(output_skip)
        output_branch1 = tf.image.resize(output_branch1, (output_skip.shape[1], output_skip.shape[2]), method=tf.image.ResizeMethod.BILINEAR)
        output_branch2 = self.branch2(output_skip)
        output_branch2 = tf.image.resize(output_branch2, (output_skip.shape[1], output_skip.shape[2]), method=tf.image.ResizeMethod.BILINEAR)
        output_branch3 = self.branch3(output_skip)
        output_branch3 = tf.image.resize(output_branch3, (output_skip.shape[1], output_skip.shape[2]), method=tf.image.ResizeMethod.BILINEAR)
        output_branch4 = self.branch4(output_skip)
        output_branch4 = tf.image.resize(output_branch4, (output_skip.shape[1], output_skip.shape[2]), method=tf.image.ResizeMethod.BILINEAR)
        # Channel-wise concatenation (axis 3) of the layer2 output, the layer4
        # output and the four branches, finest pooling window first.
        output_feature = tf.concat((output_raw, output_skip, output_branch4, output_branch3, output_branch2, output_branch1), axis=3)
        output_feature = self.lastconv(output_feature)
        return output_feature

In [4]:
class hourglass(Model):
    """3-D encoder/decoder ("hourglass") block used to regularise the cost volume.

    Two stride-2 convolutions shrink the volume, two stride-2 transposed
    convolutions restore it. Intermediate activations ``pre`` and ``post``
    are returned so that later hourglasses can reuse them as skip inputs.
    """

    def __init__(self, inplanes):
        super().__init__()
        wide = inplanes * 2

        def conv_bn_relu(stride):
            # Conv3D -> BatchNorm -> ReLU with `wide` filters.
            return tf.keras.Sequential([
                Conv3D(wide, kernel_size=3, strides=stride, padding='same'),
                BatchNormalization(),
                ReLU(),
            ])

        def deconv_bn(filters):
            # Stride-2 transposed convolution -> BatchNorm (no activation).
            return tf.keras.Sequential([
                Conv3DTranspose(filters, kernel_size=3, strides=2, padding='same'),
                BatchNormalization(),
            ])

        # Attributes are assigned in the same order as the layers are used.
        self.conv1 = conv_bn_relu(2)
        self.conv2 = Conv3D(wide, kernel_size=3, strides=1, padding='same')
        self.conv3 = conv_bn_relu(2)
        self.conv4 = conv_bn_relu(1)
        self.conv5 = deconv_bn(wide)
        self.conv6 = deconv_bn(inplanes)

    @tf.function
    def call(self, x, presqu=None, postsqu=None):
        """Run the hourglass.

        Args:
            x: input volume.
            presqu: optional tensor added after ``conv5``; when omitted, this
                block's own ``pre`` activation is used instead.
            postsqu: optional tensor added to the ``conv2`` output before the ReLU.

        Returns:
            Tuple ``(out, pre, post)``: the ``conv6`` output and the two
            intermediate activations.
        """
        pre = self.conv2(self.conv1(x))
        if postsqu is not None:
            pre = pre + postsqu
        pre = tf.nn.relu(pre)

        bottleneck = self.conv4(self.conv3(pre))

        skip = pre if presqu is None else presqu
        post = tf.nn.relu(self.conv5(bottleneck) + skip)

        return self.conv6(post), pre, post

class PSMNet(Model):
    """Stereo matching network: shared 2-D features, cost volume, stacked hourglasses.

    The input is one tensor holding the left image in channels 0-2 and the
    right image in channels 3-5. Three disparity estimates are produced, one
    per hourglass stage.

    Output format is selected by ``self.trainable`` (not by ``training``):
      * ``trainable`` is False -> only the final estimate, shape [N, H, W];
      * ``trainable`` is True  -> all three estimates stacked on the last
        axis, shape [N, H, W, 3], as expected by ``MyLoss``.
    """

    def __init__(self, maxdisp):
        super(PSMNet, self).__init__()
        self.maxdisp = maxdisp
        self.feature_extraction = feature_extraction()

        self.dres0 = tf.keras.Sequential([
            Conv3D(32, kernel_size=3, padding='same'), ReLU(),
            Conv3D(32, kernel_size=3, padding='same'), ReLU()
        ])

        self.dres1 = tf.keras.Sequential([
            Conv3D(32, kernel_size=3, padding='same'), ReLU(),
            Conv3D(32, kernel_size=3, padding='same')
        ])

        self.dres2 = hourglass(32)
        self.dres3 = hourglass(32)
        self.dres4 = hourglass(32)

        self.classif1 = tf.keras.Sequential([
            Conv3D(32, kernel_size=3, padding='same'), ReLU(),
            Conv3D(1, kernel_size=3, padding='same')
        ])

        self.classif2 = tf.keras.Sequential([
            Conv3D(32, kernel_size=3, padding='same'), ReLU(),
            Conv3D(1, kernel_size=3, padding='same')
        ])

        self.classif3 = tf.keras.Sequential([
            Conv3D(32, kernel_size=3, padding='same'), ReLU(),
            Conv3D(1, kernel_size=3, padding='same')
        ])

        self.soft_argmin = SoftArgMin(self.maxdisp)

    def _upsample_cost(self, cost, height, width):
        """Upsample a coarse cost volume to full resolution in two bilinear steps.

        ``cost`` is laid out as [N, W', H', D'] (see ``call``). The first
        resize scales the two leading spatial axes to (width, height); after
        the transpose to [N, D', H, W] the second resize scales the disparity
        axis to ``self.maxdisp``. The result is [N, maxdisp, H, W].
        """
        cost = tf.image.resize(cost, (width, height), method=tf.image.ResizeMethod.BILINEAR)
        cost = tf.image.resize(tf.transpose(cost, perm=[0, 3, 2, 1]), (self.maxdisp, height), method=tf.image.ResizeMethod.BILINEAR)
        return cost

    @tf.function
    def call(self, data, training):
        """Predict disparity for a batch of concatenated stereo pairs.

        Args:
            data: tensor whose last axis holds left RGB (0:3) and right RGB
                (3:6); assumed to be [N, H, W, 6] with static H and W.
            training: Keras training flag. It is not read here; sub-layers
                receive it through the Keras call context.
        """
        left = data[:,:,:,0:3]
        right = data[:,:,:,3:6]
        refimg_fea = self.feature_extraction(left)
        targetimg_fea = self.feature_extraction(right)
        # Insert a length-1 disparity axis: [N, 1, H', W', C].
        refimg_fea = tf.expand_dims(refimg_fea, axis=1)
        targetimg_fea = tf.expand_dims(targetimg_fea, axis=1)

        # Cost volume: for each candidate disparity d (at the feature
        # resolution) pair the left feature at column x with the right feature
        # at column x - d; columns x < d are zero-filled.
        # Uses the instance's own maxdisp instead of a notebook-level global.
        cost = []
        for d in range(self.maxdisp // 4):
            if d > 0:
                left_shift = tf.pad(refimg_fea[:, :, :, d:, :], paddings=[[0, 0], [0, 0], [0, 0], [d, 0], [0, 0]])
                right_shift = tf.pad(targetimg_fea[:, :, :, :-d, :], paddings=[[0, 0], [0, 0], [0, 0], [d, 0], [0, 0]])
                cost_plate = tf.concat([left_shift, right_shift], axis=4)
            else:
                cost_plate = tf.concat([refimg_fea, targetimg_fea], axis=4)

            cost.append(cost_plate)

        # [N, D', H', W', 2C] -> [N, W', H', D', 2C]
        cost = tf.concat(cost, axis=1)
        cost = tf.transpose(cost, perm=[0, 3, 2, 1, 4])

        # PSMNet regularization
        cost0 = self.dres0(cost)
        cost0 = self.dres1(cost0) + cost0

        out1, pre1, post1 = self.dres2(cost0)
        out1 = out1 + cost0

        out2, pre2, post2 = self.dres3(out1, pre1, post1)
        out2 = out2 + cost0

        # The third hourglass reuses pre1 (not pre2) together with post2.
        out3, pre3, post3 = self.dres4(out2, pre1, post2)
        out3 = out3 + cost0

        # Each stage's cost is accumulated on top of the previous one.
        cost1 = self.classif1(out1)
        cost2 = self.classif2(out2) + cost1
        cost3 = self.classif3(out3) + cost2
        cost1 = tf.squeeze(cost1, axis=-1)
        cost2 = tf.squeeze(cost2, axis=-1)
        cost3 = tf.squeeze(cost3, axis=-1)

        height, width = left.shape[1], left.shape[2]
        pred3 = self.soft_argmin(self._upsample_cost(cost3, height, width))

        # Read the flag from this instance rather than from a global `net`.
        # NOTE(review): `call` is wrapped in tf.function, so this Python value
        # is presumably captured when the function is traced; toggling
        # `trainable` without a retrace may not take effect -- confirm.
        if not self.trainable:
            return pred3

        pred1 = self.soft_argmin(self._upsample_cost(cost1, height, width))
        pred2 = self.soft_argmin(self._upsample_cost(cost2, height, width))
        return tf.concat([tf.expand_dims(pred1, axis=3), tf.expand_dims(pred2, axis=3), tf.expand_dims(pred3, axis=3)], axis=3)
        

def smooth_l1_loss(y_true, y_pred, beta=1.0, size_average=True, reduction='mean'):
    """Smooth L1 (Huber-style) loss between ``y_true`` and ``y_pred``.

    Per element, with ``d = |y_true - y_pred|``:

        0.5 * d**2 / beta   if d < beta
        d - 0.5 * beta      otherwise

    The previous implementation omitted the ``0.5 / beta`` factor on the
    quadratic branch, which made the gradient jump at ``d == beta`` and
    inflated the loss compared with the standard definition.

    Args:
        y_true, y_pred: tensors of broadcast-compatible shape.
        beta: threshold between the quadratic and the linear regime.
            ``beta == 0`` yields the plain L1 loss.
        size_average: when False the per-element loss is returned and
            ``reduction`` is ignored.
        reduction: 'mean' or 'sum'; only consulted when ``size_average`` is True.

    Raises:
        ValueError: if ``size_average`` is True and ``reduction`` is neither
            'mean' nor 'sum'.
    """
    abs_diff = tf.abs(y_true - y_pred)

    if beta == 0:
        # Degenerate case: no quadratic zone, and avoids dividing by zero.
        loss = abs_diff
    else:
        # Clipping at beta makes the quadratic term saturate at 0.5 * beta,
        # so adding the linear excess yields d - 0.5 * beta beyond beta.
        quadratic = 0.5 * tf.square(tf.minimum(abs_diff, beta)) / beta
        linear = tf.maximum(abs_diff - beta, 0.0)
        loss = quadratic + linear

    if size_average:
        if reduction == 'mean':
            loss = tf.reduce_mean(loss)
        elif reduction == 'sum':
            loss = tf.reduce_sum(loss)
        else:
            raise ValueError(f"Invalid reduction type: {reduction}")
    return loss

def MyLoss(y_true, y_pred):
    """Weighted sum of the smooth-L1 losses of the three disparity estimates.

    ``y_pred`` carries the three estimates on its last axis; they are
    weighted 0.5, 0.7 and 1.0 respectively. Pixels whose ground-truth
    disparity is not below 160 are zeroed in both target and prediction
    before the loss is evaluated (they still count in the mean).
    """
    stage_weights = (0.5, 0.7, 1.0)

    valid = tf.cast(y_true < 160, dtype=tf.float32)
    target = y_true * valid

    terms = [
        weight * smooth_l1_loss(target, y_pred[:, :, :, stage] * valid, size_average=True)
        for stage, weight in enumerate(stage_weights)
    ]
    # Summed left to right, in stage order.
    return terms[0] + terms[1] + terms[2]

In [5]:
#### Loading dataset path
# The three pickles are indexed per sample and passed to StereoDataloader,
# which opens the entries as image files / CSV disparity files.
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only load .pkl files that come from a trusted source.
# `with` guarantees the handles are closed even if unpickling fails.
with open('gh_imgl.pkl', 'rb') as finl, open('gh_imgr.pkl', 'rb') as finr:
    img_left = pickle.load(finl)
    img_right = pickle.load(finr)

with open('gh_disp.pkl', 'rb') as finl:
    disp_left = pickle.load(finl)

# Training / test data set sizes
tot_num1 = len(img_left)//2 # half of the samples: training data
tot_num2 = len(img_left) # all samples: training + test data
file_max = tot_num2 # training: file_max = tot_num1; performance test: file_max = tot_num2

# Data size
max_w = 984
max_h = 560
max_disp = 160

# Crop size and the margin left over for the crop offset
img_w = 960
img_h = 544
df_w = max_w - img_w
df_h = max_h - img_h

ComPerBatch = 1 #24 #256 # number of samples per batch
batch_size = np.int32(tot_num2/ComPerBatch) # number of batches (not samples per batch)
totpx_batch = img_w*img_h*ComPerBatch # pixels per batch, used to normalise the error ratios

# save_data = 1: write metrics and images to disk; save_data = 0: predict only
save_data = 1

# Image display function
def disp_img(img, title):
    """Show ``img`` inline with the 'jet' colormap, a title and a colorbar.

    The figure is closed afterwards and a garbage-collection pass is run so
    that repeated calls do not accumulate figures in memory.
    """
    fig, ax = plt.subplots(figsize=(5, 3))
    ax.set_title(title)
    ax.axis('off')
    im = ax.imshow(img, cmap='jet')
    fig.colorbar(im, ax=ax, shrink=0.5)
    fig.tight_layout()
    plt.show()

    plt.close(fig)
    gc.collect()

# Colour scale shared by every saved disparity image:
# lower bound, upper bound and colorbar tick spacing.
vmin = 40
vmax = 80
vgap = 10

# Image save function
def _save_img(img, title, dest, file_name, data_ind):
    """Render ``img`` with a colorbar and save it as a PNG.

    The file is written to ``<dest>/<file_name>_<data_ind as 5 digits>.png``.
    The colour range is fixed by the notebook-level ``vmin``/``vmax``; ticks
    are placed every ``vgap`` starting at ``vmin`` (``vmax`` itself excluded).

    Side effect: the global matplotlib font size and font family are changed
    and stay changed after the call.
    """
#     plt.switch_backend('Agg')
    plt.rc('font', size=30)        # default font size
    plt.rcParams["font.family"] = "Times New Roman"

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.set_title(title)
    ax.axis('off')
    im = ax.imshow(img, cmap='jet', vmin=vmin, vmax=vmax)

    # Colorbar in its own axes to the right of the image, 5% of its width.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cbar = fig.colorbar(im, cax=cax)
    cbar.set_ticks(np.arange(vmin, vmax, vgap))  # one tick every `vgap`

    fig.tight_layout()
    fig.savefig('%s/%s_%05d.png' %(dest,file_name,data_ind))

    plt.close(fig)
    gc.collect()

In [6]:
# PSMNet performance #############################################################################
max_epochs1 = 196
chkdir='saved_model/PSMNet_iter%d/' %(max_epochs1)

net = PSMNet(maxdisp=max_disp)
net.compile(optimizer=tf.keras.optimizers.Adam(), loss=MyLoss)
net.trainable = True

mix_set = np.arange(file_max) #if training : mix_set = np.random.permutation(file_max)

# Arguments shared by every StereoDataloader call in this cell; only the
# batch index changes between calls.
loader_args = dict(
    img_left = img_left,
    img_right = img_right,
    disp_left = disp_left,
    img_h = img_h,
    img_w = img_w,
    df_h = df_h,
    df_w = df_w,
    ComPerBatch = ComPerBatch, # component number of a batch
    data_ord = mix_set,
    max_disp = max_disp
)

batch_n = 0
DB_imgs,DB_disp = StereoDataloader(batch_num = batch_n, **loader_args)

# One fit step on the first batch before restoring the checkpoint.
# NOTE(review): presumably this only serves to create the model/optimizer
# variables so that load_weights can restore them -- confirm.
net.fit(DB_imgs, DB_disp, batch_size=1, epochs=1, verbose=0)
net.load_weights(chkdir)
net.trainable = False #True

dest_dir='PSMNet_iter%d' %(max_epochs1)
os.makedirs(dest_dir, exist_ok=True)

# Start a fresh CSV report with a header row.
csv_path = '%s/test.csv' %(dest_dir)
with open(csv_path, 'w') as f:
    f.write('Data No./Number as Disparity Difference, PSMNet_GH2470_epc%d loss, ratio of px<=1, <=2, <=3\n' %(max_epochs1))

for batch_n in range(batch_size):
    DB_imgs,DB_disp = StereoDataloader(batch_num = batch_n, **loader_args)

    # Bound up front so the cleanup below works for any value of save_data.
    result = None

    # Model
    if save_data == 0:
        print('Batch %d/%d' %(batch_n+1,batch_size))
        result = net.predict(DB_imgs, batch_size=1, verbose=1)
        result = _clamp_disp(result,0, max_disp)
    elif save_data == 1:
        print('\rBatch %d/%d' %(batch_n+1,batch_size), end='')
        result = _clamp_disp(net.predict(DB_imgs, batch_size=1, verbose=0),0, max_disp)

        # Absolute disparity error per pixel, computed once and reused.
        disp_diff1 = np.abs(DB_disp-result)
        loss1 = np.mean(disp_diff1)

        # One CSV row: batch number, mean absolute error and the fraction of
        # pixels whose error is <= 1, 2 and 3 px.
        with open(csv_path, 'a') as f:
            f.write('%d, ' %(batch_n+1))
            f.write('%.5f, ' %loss1)
            f.write('%.5f, ' %(np.sum(disp_diff1<=1)/totpx_batch))
            f.write('%.5f, ' %(np.sum(disp_diff1<=2)/totpx_batch))
            f.write('%.5f \n'%(np.sum(disp_diff1<=3)/totpx_batch))

        #### Ground-truth disparity (first sample of the batch)
        title = 'Ground Truth_Data#%d' %(batch_n+1)
        _save_img(DB_disp[0,:,:], title, dest_dir, 'True_Disp', batch_n+1)

        #### Disparity predicted by the model from the stereo pair
        title = 'PSMNet_Data#%d' %(batch_n+1)
        _save_img(result[0,:,:], title, dest_dir, 'Stereo_Disp', batch_n+1)

        del loss1
        del disp_diff1

    # Release the batch before the next one is loaded.
    del DB_imgs
    del DB_disp
    del result
    gc.collect()

del net

tf.keras.backend.clear_session()
gc.collect()

Batch 2470/2470

48388