In [None]:
from google.colab import drive
drive.mount('/content/gdrive/')
%cd gdrive/MyDrive/CS498_project/PSMNet/

Mounted at /content/gdrive/
/content/gdrive/MyDrive/CS498_project/PSMNet


In [None]:
import torch
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
from models.submodule import *
from models.stackhourglass import hourglass


import time
import math
import random
import torch
from torchvision import transforms as transforms
import numpy as np
from PIL import Image, ImageOps
from torchvision.models import resnet18
import pdb

import torch.optim as optim
from tqdm import tqdm

In [None]:
# Project root on the mounted Google Drive. Later paths are built by plain
# string concatenation onto this value, so it keeps its trailing slash.
root_path = "/content/gdrive/MyDrive/CS498_project/PSMNet/"

In [None]:

from dataloader import preprocess
from pathlib import Path


class KITTI(torch.utils.data.Dataset):
    """KITTI stereo pairs as ``(left_image, right_image, disparity)`` samples.

    Left images are found by globbing ``data_path + 'colored_0/'`` for
    ``*_10.*``. The right-image and ground-truth-disparity paths are derived
    by swapping that folder name for ``'colored_1/'`` and ``'disp_occ/'``.

    The pairs are sorted and then shuffled with a private generator seeded by
    ``seed``. Every instance built from the same ``data_path`` and ``seed``
    therefore sees the same permutation, and the ``'train'`` / ``'val'`` /
    ``'test'`` splits are disjoint: the first 85% go to train, the last 5 go
    to test, and the remainder goes to val.

    Args:
        data_path: folder holding ``colored_0/``, ``colored_1/`` and
            ``disp_occ/``. It must end with a path separator, because the
            folder names are appended by plain string concatenation.
        mode: ``'train'``, ``'val'`` or ``'test'``.
        seed: seed of the split permutation. Instances meant to share one
            split must use the same value.
    """

    def __init__(self, data_path, mode, seed=0):
        left_fold = 'colored_0/'
        right_fold = 'colored_1/'
        disp = 'disp_occ/'

        self.mode = mode

        # Sort first: glob order depends on the filesystem, which would make
        # even a seeded permutation irreproducible.
        left_paths = sorted(str(img) for img in Path(data_path + left_fold).glob("*_10.*"))
        # reshape keeps the (N, 3) layout even when no image is found.
        image_paths = np.array(
            [[p, p.replace(left_fold, right_fold), p.replace(left_fold, disp)] for p in left_paths]
        ).reshape(-1, 3)

        # Private seeded generator. Unlike the global np.random state, it gives
        # the separately constructed train/val/test instances the same order,
        # so no pair leaks from one split into another.
        rng = np.random.default_rng(seed)
        image_paths = image_paths[rng.permutation(len(image_paths))]

        n_total = len(image_paths)
        n_test = min(5, n_total)
        # Clamped so the train split never reaches into the test pairs.
        split_point = min(int(n_total * 0.85), n_total - n_test)

        data_split = {'train': image_paths[:split_point],
                      'val': image_paths[split_point:n_total - n_test],
                      'test': image_paths[n_total - n_test:]}

        self.left = data_split[self.mode][:, 0]
        self.right = data_split[self.mode][:, 1]
        self.disp_true = data_split[self.mode][:, 2]

    def __getitem__(self, index):
        """Return ``(left, right, disparity)`` for pair ``index``.

        Both images go through ``preprocess.get_transform(augment=False)``.
        The disparity is the stored PNG value divided by 256, as float32.
        In ``'train'`` mode all three are cut to the same random 256x512
        window. Otherwise they are cut to the bottom-right 368x1232 window.
        """
        left_img = Image.open(self.left[index]).convert('RGB')
        right_img = Image.open(self.right[index]).convert('RGB')
        disp_img = Image.open(self.disp_true[index])

        w, h = left_img.size
        if self.mode == 'train':
            th, tw = 256, 512
            x1 = random.randint(0, w - tw)
            y1 = random.randint(0, h - th)
            box = (x1, y1, x1 + tw, y1 + th)
        else:
            box = (w - 1232, h - 368, w, h)

        left_img = left_img.crop(box)
        right_img = right_img.crop(box)
        disp_true = np.ascontiguousarray(disp_img.crop(box), dtype=np.float32) / 256

        processed = preprocess.get_transform(augment=False)
        return processed(left_img), processed(right_img), disp_true

    def __len__(self):
        return len(self.disp_true)

In [None]:
class ResNetFeatureExtractor(nn.Module):
    """Decode 512-channel backbone features into a 32-channel feature map.

    Four conv-BN-ReLU stages reduce the channels 512 -> 256 -> 128 -> 64 -> 32.
    The map is bilinearly resized to ``mid_size`` after stages 1 and 2, and to
    ``out_size`` after stage 3. The output size is therefore fixed by
    configuration, not derived from the input size.

    Args:
        mid_size: (H, W) after the first and second resize.
        out_size: (H, W) of the returned feature map (overridable per call).
    """

    def __init__(self, mid_size=(192, 616), out_size=(96, 308)):
        super(ResNetFeatureExtractor, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        self.upsample1 = nn.Upsample(size=tuple(mid_size), mode='bilinear', align_corners=True)

        self.conv2 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        # Same target as upsample1 and conv2 preserves size, so with the
        # default configuration this resize leaves the map unchanged.
        self.upsample2 = nn.Upsample(size=tuple(mid_size), mode='bilinear', align_corners=True)

        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.upsample3 = nn.Upsample(size=tuple(out_size), mode='bilinear', align_corners=True)

        self.conv4 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, out_size=None):
        """Map ``x`` (N, 512, h, w) to (N, 32, *out_size).

        Args:
            x: backbone feature map with 512 channels (required by conv1).
            out_size: optional (H, W) overriding the constructor's
                ``out_size`` for this call, e.g. to match a different image
                crop size.
        """
        x = self.conv1(x)
        x = self.upsample1(x)
        x = self.conv2(x)
        x = self.upsample2(x)
        x = self.conv3(x)
        if out_size is None:
            x = self.upsample3(x)
        else:
            x = F.interpolate(x, size=tuple(out_size), mode='bilinear', align_corners=True)
        x = self.conv4(x)
        return x

In [None]:
class PSMNet(nn.Module):
    """PSMNet stacked-hourglass stereo network with a switchable feature stage.

    Args:
        maxdisp: maximum output disparity. The cost volume covers
            ``maxdisp // 4`` shifts at feature resolution.
        transfer_learning: if True, features come from
            ``ResNetFeatureExtractor`` applied to precomputed backbone
            features. Otherwise ``feature_extraction`` is applied to the raw
            images.
    """
    def __init__(self, maxdisp, transfer_learning = False):
        super(PSMNet, self).__init__()
        self.maxdisp = maxdisp
        # Stored so forward() takes the same feature path that __init__ built
        # (it must not depend on any global variable of the same name).
        self.transfer_learning = transfer_learning

        if transfer_learning:
            self.feature_extraction = ResNetFeatureExtractor()
        else:
            self.feature_extraction = feature_extraction()

        self.dres0 = nn.Sequential(convbn_3d(64, 32, 3, 1, 1),
                                   nn.ReLU(inplace=True),
                                   convbn_3d(32, 32, 3, 1, 1),
                                   nn.ReLU(inplace=True))

        self.dres1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
                                   nn.ReLU(inplace=True),
                                   convbn_3d(32, 32, 3, 1, 1))

        self.dres2 = hourglass(32)

        self.dres3 = hourglass(32)

        self.dres4 = hourglass(32)

        self.classif1 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
                                      nn.ReLU(inplace=True),
                                      nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1,bias=False))

        self.classif2 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
                                      nn.ReLU(inplace=True),
                                      nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1,bias=False))

        self.classif3 = nn.Sequential(convbn_3d(32, 32, 3, 1, 1),
                                      nn.ReLU(inplace=True),
                                      nn.Conv3d(32, 1, kernel_size=3, padding=1, stride=1,bias=False))

        # He-style normal init for convs, identity init for batch norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.Conv3d):
                n = m.kernel_size[0] * m.kernel_size[1]*m.kernel_size[2] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, left, right, resnet_L_fea, resnet_R_fea):
        """Predict disparity for a stereo pair.

        Args:
            left, right: image batches (N, C, H, W). With transfer learning
                only their spatial size is used.
            resnet_L_fea, resnet_R_fea: backbone features of the two images.
                Used only with transfer learning (may be None otherwise).

        Returns:
            ``(pred1, pred2, pred3)`` in training mode, ``pred3`` otherwise.
            Each prediction is the output of ``disparityregression`` at the
            full ``left`` resolution.
        """
        if self.transfer_learning:
            refimg_fea = self.feature_extraction(resnet_L_fea)
            targetimg_fea = self.feature_extraction(resnet_R_fea)
        else:
            refimg_fea = self.feature_extraction(left)
            targetimg_fea = self.feature_extraction(right)

        # Concatenation cost volume: slice i pairs left features at column x
        # with right features at column x - i. new_zeros allocates on the
        # features' own device/dtype instead of CPU-then-.cuda().
        batch, n_fea, height, width = refimg_fea.size()
        cost = refimg_fea.new_zeros(batch, n_fea * 2, self.maxdisp // 4, height, width)

        for i in range(self.maxdisp // 4):
            if i > 0:
                cost[:, :n_fea, i, :, i:] = refimg_fea[:, :, :, i:]
                cost[:, n_fea:, i, :, i:] = targetimg_fea[:, :, :, :-i]
            else:
                cost[:, :n_fea, i, :, :] = refimg_fea
                cost[:, n_fea:, i, :, :] = targetimg_fea
        cost = cost.contiguous()

        cost0 = self.dres0(cost)
        cost0 = self.dres1(cost0) + cost0

        out1, pre1, post1 = self.dres2(cost0, None, None)
        out1 = out1+cost0

        out2, pre2, post2 = self.dres3(out1, pre1, post1)
        out2 = out2+cost0

        # NOTE(review): passes pre1 (not pre2), presumably mirroring the
        # upstream PSMNet stacked-hourglass forward; confirm intentional.
        out3, pre3, post3 = self.dres4(out2, pre1, post2)
        out3 = out3+cost0

        cost1 = self.classif1(out1)
        cost2 = self.classif2(out2) + cost1
        cost3 = self.classif3(out3) + cost2

        # Cost volumes are resized to (maxdisp, H, W) of the input image.
        # align_corners=False is what the deprecated F.upsample defaulted to.
        full_size = [self.maxdisp, left.size()[2], left.size()[3]]

        if self.training:
            cost1 = F.interpolate(cost1, full_size, mode='trilinear', align_corners=False)
            cost2 = F.interpolate(cost2, full_size, mode='trilinear', align_corners=False)

            cost1 = torch.squeeze(cost1,1)
            pred1 = F.softmax(cost1,dim=1)
            pred1 = disparityregression(self.maxdisp)(pred1)

            cost2 = torch.squeeze(cost2,1)
            pred2 = F.softmax(cost2,dim=1)
            pred2 = disparityregression(self.maxdisp)(pred2)

        cost3 = F.interpolate(cost3, full_size, mode='trilinear', align_corners=False)
        cost3 = torch.squeeze(cost3,1)
        pred3 = F.softmax(cost3,dim=1)
        #For your information: This formulation 'softmax(c)' learned "similarity"
        #while 'softmax(-c)' learned 'matching cost' as mentioned in the paper.
        #However, 'c' or '-c' do not affect the performance because feature-based cost volume provided flexibility.
        pred3 = disparityregression(self.maxdisp)(pred3)

        if self.training:
            return pred1, pred2, pred3
        else:
            return pred3

In [None]:
def get_model(maximum_disp,transfer_learning,model_weights):
  """Build a DataParallel-wrapped PSMNet and load pretrained weights into it.

  Args:
    maximum_disp: maximum disparity passed to ``PSMNet``.
    transfer_learning: if True, only tensors whose name and shape match the
      checkpoint are copied; the rest keep their fresh initialisation.
      Otherwise the whole checkpoint is loaded strictly.
    model_weights: path of a checkpoint saved with ``torch.save``. It may be
      a dict holding the weights under ``'state_dict'`` or a raw state dict.

  Returns:
    The model (moved to the GPU when one is available).
  """
  model = PSMNet(maximum_disp,transfer_learning)

  # Wrapped even on a single GPU: the checkpoint keys carry the ``module.``
  # prefix that DataParallel adds.
  model = nn.DataParallel(model, device_ids=[0])

  if torch.cuda.is_available():
    model.cuda()

  # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
  checkpoint = torch.load(model_weights, map_location=None if torch.cuda.is_available() else 'cpu')

  # Checkpoints from PSMNet training keep the weights under 'state_dict'.
  # Looking keys up in the outer dict would match nothing.
  if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
    pretrained_weights = checkpoint['state_dict']
  else:
    pretrained_weights = checkpoint

  if transfer_learning:
    psm_weights = model.state_dict()
    n_loaded = 0
    for key, value in list(psm_weights.items()):
      if key in pretrained_weights and pretrained_weights[key].shape == value.shape:
        psm_weights[key] = pretrained_weights[key]
        n_loaded += 1

    model.load_state_dict(psm_weights)
    print(f"Loaded {n_loaded}/{len(psm_weights)} pretrained tensors")
  else:
    model.load_state_dict(pretrained_weights)

  params = sum(p.numel() for p in model.parameters() if p.requires_grad)

  print(f"Number of trainable parameters in model: {params}")

  return model

In [None]:
def compute_3px_err(pred_disp, disp_gt):
    """Fraction of valid pixels whose disparity estimate is wrong (3-px error).

    A pixel is valid when ``disp_gt > 0``. A valid pixel counts as correct
    when its absolute error is below 3 px *or* below 5% of the true
    disparity.

    Args:
        pred_disp: predicted disparities, same shape as ``disp_gt`` (torch
            tensor or numpy array).
        disp_gt: ground-truth disparities; zeros mark pixels without ground
            truth.

    Returns:
        The error rate in [0, 1] as a Python float.

    Raises:
        ValueError: if ``disp_gt`` contains no valid pixel.
    """
    valid = disp_gt > 0
    n_valid = int(valid.sum())
    if n_valid == 0:
        raise ValueError("compute_3px_err: disp_gt has no valid (> 0) pixels")

    gt = disp_gt[valid]
    # Builtin abs() works for both torch tensors and numpy arrays.
    abs_diff = abs(gt - pred_disp[valid])
    correct = (abs_diff < 3) | (abs_diff < gt * 0.05)

    return 1 - (float(correct.sum()) / float(n_valid))

def extract_resnet_features(resnet, imgL,imgR):
    """Run the frozen ResNet backbone on both images of a stereo pair.

    Args:
        resnet: backbone module, or ``None`` when transfer learning is off.
        imgL, imgR: image batches accepted by ``resnet``.

    Returns:
        ``(features_L, features_R)`` on the GPU when one is available
        (otherwise on the CPU), or ``(None, None)`` if ``resnet`` is None.
    """
    if resnet is None:
        return None, None

    # The backbone is frozen. eval() stops BatchNorm from overwriting its
    # pretrained running statistics on every batch, and no_grad() avoids
    # building an autograd graph that nothing will use.
    resnet.eval()
    first_param = next(resnet.parameters(), None)
    backbone_device = first_param.device if first_param is not None else imgL.device
    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    with torch.no_grad():
        fea_L = resnet(imgL.to(backbone_device))
        fea_R = resnet(imgR.to(backbone_device))

    return fea_L.to(target_device), fea_R.to(target_device)




def train(model,optimizer, imgL,imgR,disp_gt, resnet):
    """Run one optimisation step on a batch and return the scalar loss.

    The loss is a weighted sum (0.5, 0.7, 1.0) of smooth-L1 losses of the
    three PSMNet outputs. It is computed only on pixels with ground truth
    (``disp_gt > 0``).

    Args:
        model: PSMNet (possibly DataParallel-wrapped) returning three outputs
            in training mode.
        optimizer: optimizer over ``model``'s parameters.
        imgL, imgR: left/right image batches.
        disp_gt: ground-truth disparity batch.
        resnet: frozen backbone for transfer learning, or ``None``.

    Returns:
        The loss as a Python float.
    """
    model.train()
    res_fea_L, res_fea_R = extract_resnet_features(resnet, imgL, imgR)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if resnet is None:
        # With a backbone the network only reads the images' sizes, so they
        # are moved only when they are the actual network input.
        imgL = imgL.to(device)
        imgR = imgR.to(device)
    disp_gt = disp_gt.to(device)

    # Only pixels with ground truth (disparity > 0) contribute to the loss.
    mask = disp_gt > 0

    optimizer.zero_grad()

    output1, output2, output3 = model(imgL, imgR, res_fea_L, res_fea_R)
    output1 = torch.squeeze(output1, 1)
    output2 = torch.squeeze(output2, 1)
    output3 = torch.squeeze(output3, 1)

    loss = (0.5 * F.smooth_l1_loss(output1[mask], disp_gt[mask], reduction='mean')
            + 0.7 * F.smooth_l1_loss(output2[mask], disp_gt[mask], reduction='mean')
            + F.smooth_l1_loss(output3[mask], disp_gt[mask], reduction='mean'))

    loss.backward()
    optimizer.step()

    return loss.item()

def validate(model, imgL,imgR,disp_gt,resnet):
    """Evaluate ``model`` on one batch and return its 3-px error rate.

    The model is switched to eval mode and run without gradient tracking.
    Its disparity prediction is copied to the CPU, singleton dim 1 is
    squeezed, and the result is scored against ``disp_gt`` with
    ``compute_3px_err``.
    """
    model.eval()

    res_fea_L, res_fea_R = extract_resnet_features(resnet, imgL, imgR)

    # The images go to the GPU only when they are the network input, i.e.
    # when there is no ResNet backbone.
    if resnet is None:
        imgL, imgR = imgL.cuda(), imgR.cuda()

    with torch.no_grad():
        disparity = model(imgL, imgR, res_fea_L, res_fea_R)

    return compute_3px_err(disparity.data.cpu().squeeze(1), disp_gt)
        



def _build_frozen_resnet():
  """Return an ImageNet ResNet-18 trunk (avgpool/fc removed), frozen and in eval mode."""
  backbone = resnet18(pretrained = True).cpu()
  backbone = nn.Sequential(*list(backbone.children())[:-2])
  for param in backbone.parameters():
    param.requires_grad = False
  # eval(): the BatchNorm layers must keep their pretrained statistics.
  return backbone.eval()


def train_loop(model, trainDataLoader, valDataLoader, optimizer, transfer_learning = False, epochs = 40, save = False):
  """Fine-tune ``model`` for ``epochs`` epochs, validating after each epoch.

  Args:
    model: PSMNet to train.
    trainDataLoader, valDataLoader: loaders yielding ``(imgL, imgR, disp)``.
    optimizer: optimizer over ``model``'s parameters.
    transfer_learning: if True, a frozen ResNet-18 trunk feeds the model.
    epochs: number of epochs.
    save: if True, write a checkpoint whenever validation error improves, and
      the per-epoch loss/error histories as ``.npy`` files at the end, all
      under ``root_path + 'finetuned_models/'``.

  Returns:
    ``(train_loss_per_epoch, val_err_per_epoch)``: mean training loss and
    mean validation 3-px error (in %) for each epoch.
  """
  least_err = float('inf')
  best_epoch = 0
  start_full_time = time.time()
  train_loss_per_epoch = []
  val_err_per_epoch = []

  save_path = root_path + 'finetuned_models/'
  # Defined unconditionally: every output file name uses it.
  model_name = 'resnet_finetune_'

  resnet = _build_frozen_resnet() if transfer_learning else None

  for epoch in range(1, epochs + 1):

    total_train_loss = 0
    for imgL_crop, imgR_crop, disp_gt_crop in tqdm(trainDataLoader, total=len(trainDataLoader)):
      total_train_loss += train(model, optimizer, imgL_crop, imgR_crop, disp_gt_crop, resnet)

    total_val_error = 0
    for imgL, imgR, disp_gt in valDataLoader:
      total_val_error += validate(model, imgL, imgR, disp_gt, resnet)

    err = total_val_error / len(valDataLoader) * 100
    running_loss = total_train_loss / len(trainDataLoader)
    train_loss_per_epoch.append(running_loss)
    val_err_per_epoch.append(err)

    print('epoch %d running training loss = %.3f avg validation 3px-error = %.3f' %(epoch,running_loss ,err))
    if err < least_err:
      least_err = err
      best_epoch = epoch

      if save:
        savefilename = save_path + model_name + str(epoch) + '.tar'
        torch.save({
              'epoch': epoch,
              'state_dict': model.state_dict(),
              'train_loss': running_loss,
              'val_error_3px': err,
          }, savefilename)

  if save:
    np.save(save_path + model_name + 'train_loss.npy', train_loss_per_epoch)
    np.save(save_path + model_name + 'val_error.npy', val_err_per_epoch)

  print('full finetune time = %.2f HR' %((time.time() - start_full_time)/3600))
  print('Best epoch %d Best Avg error 3px = %.3f' %(best_epoch, least_err))

  return train_loss_per_epoch, val_err_per_epoch


In [None]:
# KITTI stereo folders under the project root. The trailing '/' matters
# because the dataset class appends sub-folder names by string concatenation.
# NOTE(review): test_path is not used below. The 'test' split is carved from
# training/, presumably because testing/ ships without ground-truth disparity.
test_path = root_path + 'dataset/data_stereo_flow/testing/'
train_path = root_path + 'dataset/data_stereo_flow/training/'

In [None]:
# Training batches are reshuffled every epoch. Validation/test order does not
# affect their averaged error, so those loaders stay in a fixed order.
trainDataLoader = torch.utils.data.DataLoader(
         KITTI(train_path, mode = 'train'),
         batch_size= 4, shuffle= True, drop_last=False)

valDataLoader = torch.utils.data.DataLoader(
         KITTI(train_path, mode = 'val'),
         batch_size= 8, shuffle= False, drop_last=False)

testDataLoader = torch.utils.data.DataLoader(
         KITTI(train_path, mode = 'test'),
         batch_size= 8, shuffle= False, drop_last=False)

In [None]:
transfer_learning = True
model_weights_path = root_path + 'model_weights/pretrained_sceneflow_new.tar'
maximum_disp = 192
# Adam's conventional learning rate. Rates orders of magnitude larger make
# Adam diverge immediately.
learning_rate = 1e-3

# Seed every RNG in play: torch (weight init, DataLoader shuffling), Python's
# `random` (random training crops) and numpy.
random.seed(5)
np.random.seed(5)
torch.manual_seed(5)

if torch.cuda.is_available():
  torch.cuda.manual_seed(17)
model = get_model(maximum_disp,transfer_learning, model_weights_path)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999))

Number of trainable parameters in model: 3453376


In [None]:
# Fine-tune for 40 epochs without writing checkpoints to Drive.
train_loop(
    model,
    trainDataLoader,
    valDataLoader,
    optimizer,
    transfer_learning=transfer_learning,
    epochs=40,
    save=False,
)

In [None]:
def test(model, testDataLoader, resnet):
  """Return the average 3-px error (in %) of ``model`` over ``testDataLoader``.

  The error is averaged per batch, in the same way ``train_loop`` averages
  the validation error.

  Args:
    model: trained PSMNet.
    testDataLoader: loader yielding ``(imgL, imgR, disp)`` batches.
    resnet: the frozen backbone used in training, or ``None``.

  Raises:
    ValueError: if ``testDataLoader`` yields no batches.
  """
  n_batches = len(testDataLoader)
  if n_batches == 0:
    raise ValueError("test: testDataLoader is empty")

  total_test_error = 0.0
  for imgL, imgR, disp_gt in testDataLoader:
    total_test_error += validate(model, imgL, imgR, disp_gt, resnet)

  avg_test_err = total_test_error / n_batches * 100

  print(f"average test 3px error: {avg_test_err}")
  return avg_test_err
  

In [None]:
# The backbone used during training is local to train_loop, so build an
# equivalent frozen ResNet-18 trunk here (or pass None without transfer learning).
if transfer_learning:
  resnet = nn.Sequential(*list(resnet18(pretrained = True).children())[:-2])
  for param in resnet.parameters():
    param.requires_grad = False
  resnet.eval()
else:
  resnet = None

test(model, testDataLoader, resnet)