<a href="https://colab.research.google.com/github/davidenko2000/stereo-matching-CNN/blob/main/StereoCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Patch size, number of candidate disparities searched at inference, batch size.
PATCH_SIZE, MAX_DISP, BATCH_SIZE = 9, 192, 128

# NOTE(review): not referenced in the visible cells; presumably the pixel
# error threshold used for evaluation — confirm.
PIXEL_ERROR = 3

# Learning-rate schedule: LR is used until the epoch with index
# LR_CHANGE_AT_EPOCH has finished, LR_AFTER_10_EPOCHS afterwards (see the
# training loop).
# NOTE(review): the constant's name says "after 10 epochs" but the switch is
# keyed on epoch index 11 — confirm which one is intended.
LR, LR_CHANGE_AT_EPOCH, LR_AFTER_10_EPOCHS = 0.001, 11, 0.0001

# Feature width of the network, triplet-loss margin, number of training epochs.
FEATURES, MARGIN, EPOCHS = 64, 0.2, 14

In [None]:
# Root of the stereo dataset on the mounted Google Drive.
IMAGES_DIR = 'drive/MyDrive/data_scene_flow/'

# Sub-trees: ground-truth disparity maps, RGB pairs and their grayscale copies.
DISP_DIR = f'{IMAGES_DIR}disparity/'
RGB_DIR = f'{IMAGES_DIR}RGB/'
GRAY_DIR = f'{IMAGES_DIR}GRAY/'

RGB_LEFT_DIR, RGB_RIGHT_DIR = f'{RGB_DIR}left/', f'{RGB_DIR}right/'
GRAY_LEFT_DIR, GRAY_RIGHT_DIR = f'{GRAY_DIR}left/', f'{GRAY_DIR}right/'

# When True, images are read from the GRAY/ tree instead of RGB/.
IS_GRAY = False
# NOTE(review): not referenced in the visible cells.
BNORM = False

PATCH_SIZE = 9
# Maybe increase this.
# NOTE(review): not referenced in the visible cells; the disparity search
# uses MAX_DISP instead.
MAX_DISPARITY = 100

# Image index ranges: [TRAIN_START, TRAIN_END) for training and
# [VALID_START, VALID_END) for validation.
NUM_IMAGES = 200
TRAIN_START, TRAIN_END = 0, 160
VALID_START, VALID_END = 160, 200

# Output folders for the generated disparity sample files.
TRAIN_DATA = './train/'
VALID_DATA = './valid/'

In [None]:
import os
import numpy as np
import skimage
import torch
from matplotlib import image as mpimg
from torch import nn
import torchvision.datasets
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image

def get_filename(idx):
    """Return the image filename for index ``idx``.

    The index is converted to text, left-padded with zeros to at least six
    characters, and suffixed with ``_10.png``, e.g. ``5`` -> ``"000005_10.png"``.
    """
    stem = str(idx).zfill(6)
    return f"{stem}_10.png"

def get_disp_image(idx):
    """Load the ground-truth disparity image for index ``idx`` as uint8.

    The PNG under ``DISP_DIR`` is read with matplotlib and converted with
    ``skimage.util.img_as_ubyte``. A zero value is treated as "unknown
    disparity" by ``make_disparity_data``.

    NOTE(review): matplotlib returns PNG data as floats in [0, 1] and
    img_as_ubyte rescales that to [0, 255]; if the source PNGs are 16-bit,
    this quantises the stored values — confirm the intended disparity units.
    """
    path = DISP_DIR + get_filename(idx)
    raw = mpimg.imread(path)
    return skimage.util.img_as_ubyte(raw)

def get_image(idx, is_left):
    """Load the left or right image of stereo pair ``idx``.

    The image is read from the GRAY tree when ``IS_GRAY`` is set, otherwise
    from the RGB tree, and returned as the ndarray produced by
    ``matplotlib.image.imread``.
    """
    if is_left:
        folder = GRAY_LEFT_DIR if IS_GRAY else RGB_LEFT_DIR
    else:
        folder = GRAY_RIGHT_DIR if IS_GRAY else RGB_RIGHT_DIR
    return mpimg.imread(folder + get_filename(idx))
  
def compute_disparity_map(idx, model, transform=None):
    """Compute a disparity map for stereo pair ``idx`` with a trained ``model``.

    Both images are padded by ``PATCH_SIZE // 2`` pixels on each spatial side,
    converted to tensors and passed through ``model`` to obtain per-pixel
    feature maps. For every candidate disparity ``d`` in ``range(MAX_DISP)``
    the right feature map is shifted ``d`` columns to the right and multiplied
    element-wise with the left one. The sum over channels (a dot product) is
    the matching score, and the best-scoring ``d`` is chosen per pixel.

    Args:
        idx: Index of the stereo pair (passed to ``get_image``).
        model: Feature network; it is moved to ``'cuda'`` and put in eval mode.
        transform: Callable turning an HxW / HxWxC ndarray into a CxHxW
            tensor. Defaults to ``transforms.ToTensor()``. Pass the same
            normalising transform used for training so that features are
            comparable to those seen during training.

    Returns:
        Integer ndarray of shape (H', W') with disparities in
        ``[0, MAX_DISP)``, where H' x W' is the spatial size of the model output.
    """
    max_disp = MAX_DISP
    device = 'cuda'
    model.to(device).eval()
    # The original called transforms.ToTensor(array), which passes the array
    # to the class constructor and raises TypeError; an instance is needed.
    to_tensor = transform if transform is not None else transforms.ToTensor()
    padding = PATCH_SIZE // 2

    def _feature_map(img):
        """Run the model on a padded image and return its features as (H', W', F)."""
        # Pad only the two spatial axes; a colour image's channel axis must
        # not be padded (a 2-axis pad_width is rejected for 3-D arrays).
        pad_width = ((padding, padding), (padding, padding)) + ((0, 0),) * (img.ndim - 2)
        batch = to_tensor(np.pad(img, pad_width)).unsqueeze(0).to(device)
        with torch.no_grad():
            features = model(batch)
        # Assumes the model output is (1, F, H', W') — TODO confirm for
        # models other than the one trained in this notebook.
        return features.squeeze(0).permute(1, 2, 0).cpu().numpy()

    left_features = _feature_map(get_image(idx, is_left=True))
    right_features = _feature_map(get_image(idx, is_left=False))

    # (H', W', max_disp) matching scores. np.roll wraps around, so for the
    # leftmost d columns the score compares against columns from the right edge.
    scores = np.stack(
        [np.sum(left_features * np.roll(right_features, d, axis=1), axis=2) for d in range(max_disp)],
        axis=2,
    )
    # Most similar candidate per pixel.
    return np.argmax(scores, axis=2)

def get_mean_std():
    """Return per-channel ``(mean, std)`` tensors of the dataset images.

    Images are loaded with ``torchvision.datasets.ImageFolder`` from
    ``GRAY_DIR`` when ``IS_GRAY`` is set, otherwise from ``RGB_DIR``. Each image
    is resized (shorter side to 256), centre-cropped to 256x256, converted to
    one channel if ``IS_GRAY``, and turned into a tensor. All images are
    loaded as a single batch, so the whole set must fit in memory. The values
    are intended for ``transforms.Normalize``.
    """
    steps = [transforms.Resize(256), transforms.CenterCrop(256)]
    if IS_GRAY:
        steps.append(transforms.Grayscale())
    steps.append(transforms.ToTensor())

    dataset = torchvision.datasets.ImageFolder(
        root=GRAY_DIR if IS_GRAY else RGB_DIR,
        transform=transforms.Compose(steps),
    )
    loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False, num_workers=0)
    batch, _labels = next(iter(loader))
    # Reduce over images (0) and pixels (2, 3), keeping the channel axis.
    reduce_dims = (0, 2, 3)
    return batch.mean(dim=reduce_dims), batch.std(dim=reduce_dims)

def load_disparity_data(train=True):
    """Load the structured sample array written by ``make_disparity_data``.

    Reads ``disparities.npy`` from ``TRAIN_DATA`` when ``train`` is true,
    otherwise from ``VALID_DATA``.
    """
    folder = TRAIN_DATA if train else VALID_DATA
    return np.load(folder + 'disparities.npy')

def convert_to_grayscale():
    """Write grayscale copies of the first ``NUM_IMAGES`` RGB stereo pairs.

    For each index, the left and right RGB images are converted with PIL's
    ``"L"`` mode and saved under the same filename in ``GRAY_LEFT_DIR`` and
    ``GRAY_RIGHT_DIR``. The output directories are created if they are missing.
    """
    pairs = ((RGB_LEFT_DIR, GRAY_LEFT_DIR), (RGB_RIGHT_DIR, GRAY_RIGHT_DIR))
    for _, dst_dir in pairs:
        # PIL's save() does not create missing parent directories.
        os.makedirs(dst_dir, exist_ok=True)

    for idx in range(NUM_IMAGES):
        filename = get_filename(idx)
        for src_dir, dst_dir in pairs:
            # The context manager closes the source file handle immediately
            # instead of leaving it open until garbage collection.
            with Image.open(src_dir + filename) as img:
                img.convert("L").save(dst_dir + filename)

def make_disparity_data(train=True):
    """Build and save (left, positive, negative) pixel samples for a split.

    For every image index in the train or validation range, each pixel at
    ``(row, col)`` with a non-zero ground-truth disparity ``d`` yields:

    * ``col_pos = col - d``: the matching column in the right image;
    * ``col_neg = col_pos + o``: a non-matching column, with ``o`` drawn
      uniformly from ``{-8..-4, 4..8}``.

    Samples whose ``PATCH_SIZE`` window around ``row``/``col``/``col_pos``/
    ``col_neg`` would leave the image are discarded. The result is a
    structured array with fields ``idx``, ``row``, ``col``, ``col_pos`` and
    ``col_neg``. It is saved as ``disparities.npy`` in ``TRAIN_DATA`` or
    ``VALID_DATA``, and the directory is created if it is missing.
    """
    distance = PATCH_SIZE // 2
    # NOTE: 'idx' is uint8, so at most 256 distinct image indices fit.
    record_dtype = np.dtype([('idx', 'uint8'), ('row', 'uint16'), ('col', 'uint16'),
                             ('col_pos', 'uint16'), ('col_neg', 'uint16')])
    # Offsets 4..8 pixels away from the true match, on either side.
    neg_offsets = [-8, -7, -6, -5, -4, 4, 5, 6, 7, 8]

    def _inside(values, size):
        """Mask of centres whose PATCH_SIZE window fits inside [0, size)."""
        return (values >= distance) & (values < size - distance)

    def _samples_for_image(idx):
        """Return the valid sample records of stereo pair ``idx``."""
        disp_image = get_disp_image(idx)
        height, width = disp_image.shape[0], disp_image.shape[1]
        rows, cols = disp_image.nonzero()  # pixels with known (non-zero) disparity
        # Signed arithmetic: a column shifted past the left border becomes
        # negative and is rejected by _inside(), rather than relying on
        # unsigned wrap-around as the uint16 version did.
        rows = rows.astype(np.int32)
        cols = cols.astype(np.int32)
        disparity_values = disp_image[rows, cols].astype(np.int32)

        pos_cols = cols - disparity_values
        neg_cols = pos_cols + np.random.choice(neg_offsets, size=pos_cols.size)

        keep = (_inside(rows, height) & _inside(cols, width)
                & _inside(pos_cols, width) & _inside(neg_cols, width))

        result = np.empty(int(keep.sum()), dtype=record_dtype)
        result['idx'] = idx
        result['row'] = rows[keep]
        result['col'] = cols[keep]
        result['col_pos'] = pos_cols[keep]
        result['col_neg'] = neg_cols[keep]
        return result

    first, last = (TRAIN_START, TRAIN_END) if train else (VALID_START, VALID_END)
    out_dir = TRAIN_DATA if train else VALID_DATA
    disparities = np.concatenate([_samples_for_image(idx) for idx in range(first, last)])
    # np.save does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    np.save(out_dir + 'disparities.npy', disparities)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class StereoCNN(nn.Module):
    """Four-layer convolutional feature extractor with L2-normalised output.

    Three ``conv -> ReLU`` stages are followed by a final conv with no
    activation. All convolutions use ``ksize``/``padding`` and produce
    ``features`` channels.

    With ``padding=0`` and a 9x9 input, each 3x3 conv shrinks the spatial size
    by 2, leaving 1x1. The spatial axes are then squeezed to give a
    ``(batch, features)`` vector. With other shapes, an axis is only squeezed
    if it has size 1. The output is L2-normalised along dim 1 (channels).
    """

    def __init__(self, in_channels=3, features=64, ksize=3, padding=1):
        super().__init__()
        conv_kwargs = dict(kernel_size=ksize, padding=padding)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=features, **conv_kwargs)
        self.conv2 = nn.Conv2d(in_channels=features, out_channels=features, **conv_kwargs)
        self.conv3 = nn.Conv2d(in_channels=features, out_channels=features, **conv_kwargs)
        self.conv4 = nn.Conv2d(in_channels=features, out_channels=features, **conv_kwargs)

    def forward(self, x):
        for hidden in (self.conv1, self.conv2, self.conv3):
            x = F.relu(hidden(x))
        x = self.conv4(x)
        # squeeze(dim) is a no-op unless that dimension has size 1.
        x = x.squeeze(3).squeeze(2)
        return F.normalize(x, p=2.0, dim=1)

In [None]:
from torch.utils.data import Dataset

class PatchesExtractor(Dataset):
    """Dataset of (left, right-positive, right-negative) image patches.

    Each item comes from one record of the structured array returned by
    ``load_disparity_data``. It consists of a ``PATCH_SIZE`` x ``PATCH_SIZE``
    patch centred at ``(row, col)`` in the left image, plus patches of the
    same size centred at ``(row, col_pos)`` and ``(row, col_neg)`` in the
    right image. If ``tform`` is truthy it is applied to each patch;
    otherwise the raw ndarray views are returned.
    """

    def __init__(self, tform, train=True):
        self.transform = tform
        self.disparity_data = load_disparity_data(train=train)
        self.len = self.disparity_data.size
        self.left_images = {}
        self.right_images = {}

        first, last = (TRAIN_START, TRAIN_END) if train else (VALID_START, VALID_END)
        # Every image of the split is loaded once up front and kept in memory.
        for idx in range(first, last):
            self.left_images[idx] = get_image(idx, is_left=True)
            self.right_images[idx] = get_image(idx, is_left=False)

    @staticmethod
    def _crop(image, row, col):
        """Return the PATCH_SIZE x PATCH_SIZE window of ``image`` centred at (row, col)."""
        half = PATCH_SIZE // 2
        # Plain ints avoid unsigned numpy-scalar arithmetic surprises.
        top, left = int(row) - half, int(col) - half
        # Slice ends are exclusive: `top + PATCH_SIZE` yields exactly
        # PATCH_SIZE rows (the original `row + half` produced one fewer).
        return image[top:top + PATCH_SIZE, left:left + PATCH_SIZE]

    def __getitem__(self, patch_idx):
        """Extract the three patches described by record ``patch_idx``."""
        record = self.disparity_data[patch_idx]
        image_idx = int(record['idx'])
        row = record['row']
        left_image = self.left_images[image_idx]
        right_image = self.right_images[image_idx]

        patches = (
            self._crop(left_image, row, record['col']),
            self._crop(right_image, row, record['col_pos']),
            self._crop(right_image, row, record['col_neg']),
        )
        # Previously a falsy transform left the patch names unbound and
        # raised UnboundLocalError; now the raw patches are returned.
        if self.transform:
            patches = tuple(self.transform(patch) for patch in patches)
        return patches

    def __len__(self):
        return self.len

In [None]:
# Generate the training samples first, then the validation samples.
for is_train in (True, False):
    make_disparity_data(train=is_train)


In [None]:
# Train on the GPU when one is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The input channel count must match what get_image() returns: one channel
# for the GRAY tree, three for RGB (selected by IS_GRAY).
# NOTE(review): with the default padding=1 every conv preserves the patch's
# spatial size, so forward() does not reduce to a (batch, FEATURES) vector as
# its squeeze comment expects; padding=0 on 9x9 patches would. Confirm intent.
model = StereoCNN(in_channels=1 if IS_GRAY else 3, features=FEATURES)
model.to(device)

# Per-channel normalisation statistics, presumably obtained from get_mean_std().
RGB_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4817, 0.5085, 0.5006), (0.3107, 0.3249, 0.3350)),
])
GRAY_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4999,), (0.3180,)),
])

# The transform must match the image type, or normalisation fails on a channel mismatch.
train_dataset = PatchesExtractor(tform=GRAY_transform if IS_GRAY else RGB_transform, train=True)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

criterion = nn.TripletMarginLoss(margin=MARGIN)
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)

# Mean training loss per epoch, filled in by train().
train_report = np.zeros((EPOCHS, ))
# Total number of batches processed across all epochs.
num_batches = 0

def train(epoch):
    """Train ``model`` for one epoch over ``train_dataloader``.

    For each (left, positive, negative) patch batch, the embeddings are
    computed and ``criterion`` is minimised with ``optimizer``. Progress is
    printed every 1000 batches. At the end of the epoch, the mean loss is
    stored in ``train_report[epoch]`` and the whole model is saved to
    ``train_model_<epoch>.pth``. The module-level ``num_batches`` counter is
    advanced once per batch.
    """
    global num_batches
    print('\nEpoch: %d' % epoch)
    running_loss = 0
    seen = 0
    model.train()

    for batch in train_dataloader:
        num_batches += 1
        seen += 1
        anchor, positive, negative = (patch.to(device) for patch in batch)
        anchor_emb = model(anchor)
        positive_emb = model(positive)
        negative_emb = model(negative)

        optimizer.zero_grad()
        loss = criterion(anchor_emb, positive_emb, negative_emb)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        if seen % 1000 == 0:
            progress = ((BATCH_SIZE * seen) / len(train_dataset)) * 100
            print(f"Done {progress:.3f} %")
            print('Train -> Loss: %.3f \n' % (running_loss / seen))

    train_report[epoch] = running_loss / seen
    torch.save(model, f"train_model_{epoch}.pth")


# Epoch loop. The learning rate switches to LR_AFTER_10_EPOCHS once the epoch
# with index LR_CHANGE_AT_EPOCH has finished, so the new rate takes effect
# from the following epoch.
# NOTE(review): with LR_CHANGE_AT_EPOCH = 11 that is after 12 epochs, which
# does not match the constant's "after 10 epochs" name — confirm intent.
for epoch in range(EPOCHS):
    train(epoch)
    if epoch != LR_CHANGE_AT_EPOCH:
        continue
    for group in optimizer.param_groups:
        group['lr'] = LR_AFTER_10_EPOCHS

# Persist the per-epoch mean training loss.
np.save(TRAIN_DATA + 'training_report.npy', train_report)