<a href="https://colab.research.google.com/github/davidenko2000/stereo-matching-CNN/blob/main/Evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In [None]:
# Side length (in pixels) of the square image patch; PATCH_SIZE // 2 is used
# below as the border margin / zero-padding width.
PATCH_SIZE = 9
# Number of candidate disparities (0 .. MAX_DISP - 1) evaluated per pixel by
# compute_disparity_map.
MAX_DISP = 229
# NOTE(review): not referenced in the code visible here — presumably the
# training mini-batch size; confirm against the training notebook.
BATCH_SIZE = 128

# A predicted disparity counts as correct when its absolute difference to the
# ground truth is strictly smaller than this many pixels.
PIXEL_ERROR = 3

# NOTE(review): the three learning-rate settings below are not referenced in
# the code visible here — presumably the initial rate, the epoch at which it
# is changed, and the rate used afterwards; confirm against the training code.
LR = 0.001
LR_CHANGE_AT_EPOCH = 10
LR_AFTER_10_EPOCHS = 0.0001


# NOTE(review): FEATURES and MARGIN are not referenced in the code visible
# here. FEATURES equals the default `features` argument of StereoCNN; MARGIN
# is presumably a loss margin — confirm against the training code.
FEATURES = 64
MARGIN = 0.2
# Number of training epochs; the checkpoint loaded for evaluation is the one
# named after the last epoch (EPOCHS - 1).
EPOCHS = 20

In [None]:
# Root directory of the dataset (relative path; every other path is built
# from it by string concatenation, hence the trailing slashes).
IMAGES_DIR = 'drive/MyDrive/data_scene_flow/'

# Ground-truth disparity images, colour images and their grayscale copies.
DISP_DIR = IMAGES_DIR + 'disparity/'
RGB_DIR = IMAGES_DIR + 'RGB/'
GRAY_DIR = IMAGES_DIR + 'GRAY/'

# Left/right views of each stereo pair, for both colour modes.
RGB_LEFT_DIR = RGB_DIR + 'left/'
RGB_RIGHT_DIR = RGB_DIR + 'right/'
GRAY_LEFT_DIR = GRAY_DIR + 'left/'
GRAY_RIGHT_DIR = GRAY_DIR + 'right/'

# Selects the GRAY_* directories (True) or the RGB_* directories (False) in
# get_image and get_mean_std.
IS_GRAY = False

# NOTE(review): PATCH_SIZE repeats the value assigned in the previous cell,
# and MAX_DISPARITY has the same value as MAX_DISP but is not referenced in
# the code visible here — consider keeping a single definition of each.
PATCH_SIZE = 9
MAX_DISPARITY = 229 

# Total number of stereo pairs and the half-open index ranges
# [START, END) used for the training and validation splits.
NUM_IMAGES = 200
TRAIN_START = 0
TRAIN_END = 160
VALID_START = 160
VALID_END = 200

# Directories holding the generated 'disparities.npy' sample files; the
# trained model checkpoints are also loaded from TRAIN_DATA.
TRAIN_DATA = IMAGES_DIR + 'train/'
VALID_DATA = IMAGES_DIR + 'valid/'

In [None]:
import os
import numpy as np
import skimage
import torch
from matplotlib import image as mpimg
from torch import nn
import torchvision.datasets
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image

def get_filename(idx):
    """Return the image file name for index `idx`.

    The index is converted to text, left-padded with zeros to at least six
    characters, and followed by the fixed suffix "_10.png".
    """
    prefix = str(idx).zfill(6)
    return f"{prefix}_10.png"

def get_disp_image(idx):
    """Load the ground-truth disparity image with index `idx`.

    The file is read from DISP_DIR with matplotlib and converted to 8-bit
    unsigned integers by `skimage.util.img_as_ubyte`.
    NOTE(review): this assumes the 8-bit rescaling yields disparities in
    pixels — confirm against the encoding of the disparity files.
    """
    path = DISP_DIR + get_filename(idx)
    raw_disparity = mpimg.imread(path)
    return skimage.util.img_as_ubyte(raw_disparity)

def get_image(idx, is_left):
    """Load the left or right view of stereo pair `idx`.

    The directory is chosen from the colour mode (IS_GRAY) and the requested
    side (`is_left`); the image is returned as read by
    `matplotlib.image.imread`.
    """
    if IS_GRAY:
        directory = GRAY_LEFT_DIR if is_left else GRAY_RIGHT_DIR
    else:
        directory = RGB_LEFT_DIR if is_left else RGB_RIGHT_DIR
    return mpimg.imread(directory + get_filename(idx))

def get_mean_std():
    """Compute the mean and standard deviation used for normalization.

    Every image under GRAY_DIR or RGB_DIR (depending on IS_GRAY) is resized
    and centre-cropped to 256, converted to a tensor, and loaded as a single
    batch. The statistics are reduced over dimensions 0, 2 and 3 of that
    batch, leaving one value per channel.

    Returns:
        A `(mean, std)` tuple of tensors.
    """
    steps = [transforms.Resize(256), transforms.CenterCrop(256)]
    if IS_GRAY:
        steps.append(transforms.Grayscale())
    steps.append(transforms.ToTensor())

    dataset = torchvision.datasets.ImageFolder(
        root=GRAY_DIR if IS_GRAY else RGB_DIR,
        transform=transforms.Compose(steps),
    )
    # One batch containing the whole dataset, so the statistics are global.
    loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False, num_workers=0)
    batch, _ = next(iter(loader))

    reduce_dims = [0, 2, 3]
    return batch.mean(reduce_dims), batch.std(reduce_dims)

def load_disparity_data(train=True):
    """Load the saved sample array 'disparities.npy'.

    Reads from TRAIN_DATA when `train` is true, otherwise from VALID_DATA,
    and returns the array produced by `np.load`.
    """
    data_dir = TRAIN_DATA if train else VALID_DATA
    return np.load(data_dir + 'disparities.npy')

def convert_to_grayscale():
    """Write grayscale copies of all left and right RGB images.

    For every index in range(NUM_IMAGES) the left and right RGB images are
    converted to mode "L" and saved under the same file name in the
    corresponding GRAY directory. The destination directories are created
    if they do not exist yet.
    """
    directory_pairs = (
        (RGB_LEFT_DIR, GRAY_LEFT_DIR),
        (RGB_RIGHT_DIR, GRAY_RIGHT_DIR),
    )
    # Saving would fail with FileNotFoundError on a fresh dataset copy.
    for _, gray_dir in directory_pairs:
        os.makedirs(gray_dir, exist_ok=True)

    for idx in range(NUM_IMAGES):
        filename = get_filename(idx)
        for rgb_dir, gray_dir in directory_pairs:
            # The context manager closes the source file deterministically
            # instead of leaving it to garbage collection.
            with Image.open(rgb_dir + filename) as rgb_image:
                rgb_image.convert("L").save(gray_dir + filename)

def make_disparity_data(train=True):
    """Build the patch-sample table for one split and save it to disk.

    For every image of the split (TRAIN_START..TRAIN_END when `train` is
    true, otherwise VALID_START..VALID_END) each pixel with a known
    (non-zero) disparity yields one record:

        idx      image index
        row/col  pixel position in the left image
        col_pos  matching column in the right image (col - disparity)
        col_neg  a non-matching column, col_pos shifted by a random offset
                 of 4..8 pixels to either side

    Records whose patch (PATCH_SIZE x PATCH_SIZE around any of the three
    columns) would leave the image are dropped. The concatenated structured
    array is written to 'disparities.npy' in TRAIN_DATA or VALID_DATA.

    Raises:
        ValueError: if an image index does not fit the uint8 'idx' field.
    """
    half_patch = PATCH_SIZE // 2
    negative_offsets = [-8, -7, -6, -5, -4, 4, 5, 6, 7, 8]
    record_dtype = np.dtype([('idx', 'uint8'), ('row', 'uint16'), ('col', 'uint16'),
                             ('col_pos', 'uint16'), ('col_neg', 'uint16')])

    def _samples_for_image(idx):
        # The on-disk record stores the image index as uint8; fail loudly
        # instead of letting a larger index wrap around silently.
        if not 0 <= idx <= np.iinfo(np.uint8).max:
            raise ValueError(f"image index {idx} does not fit the uint8 'idx' field")

        disp_image = get_disp_image(idx)
        height, width = disp_image.shape[0], disp_image.shape[1]

        # Pixels with non-zero value are the ones with known disparity.
        rows, cols = disp_image.nonzero()
        disparity_values = disp_image[rows, cols].astype(np.int32)
        # Signed arithmetic: out-of-image columns become negative and are
        # rejected by the range checks below, rather than relying on
        # unsigned wraparound producing a huge value.
        rows = rows.astype(np.int32)
        cols = cols.astype(np.int32)

        pos_cols = cols - disparity_values  # column of the true match
        neg_cols = pos_cols + np.random.choice(negative_offsets, size=pos_cols.size)  # wrong match

        def _inside(values, limit):
            # True where a full patch centred on the value fits in [0, limit).
            return (values >= half_patch) & (values < limit - half_patch)

        keep = (_inside(rows, height) & _inside(cols, width)
                & _inside(pos_cols, width) & _inside(neg_cols, width))

        result = np.empty(np.count_nonzero(keep), dtype=record_dtype)
        result['idx'] = idx
        result['row'] = rows[keep]
        result['col'] = cols[keep]
        result['col_pos'] = pos_cols[keep]
        result['col_neg'] = neg_cols[keep]
        return result

    if train:
        first, last = TRAIN_START, TRAIN_END
    else:
        first, last = VALID_START, VALID_END

    disparities = np.concatenate([_samples_for_image(idx) for idx in range(first, last)])
    np.save((TRAIN_DATA if train else VALID_DATA) + 'disparities.npy', disparities)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class StereoCNN(nn.Module):
    """Four stacked 2-D convolutions producing a feature vector per position.

    The first three convolutions are followed by ReLU; the last one is left
    linear. All layers share the same kernel size, padding and number of
    output features.
    """

    def __init__(self, in_channels=3, features=64, ksize=3, padding=1):
        """Create the four convolution layers.

        Args:
            in_channels: number of channels of the input image or patch.
            features: number of output channels of every convolution.
            ksize: kernel size of every convolution.
            padding: padding of every convolution.
        """
        super().__init__()
        shared = dict(out_channels=features, kernel_size=ksize, padding=padding)
        self.conv1 = nn.Conv2d(in_channels=in_channels, **shared)
        self.conv2 = nn.Conv2d(in_channels=features, **shared)
        self.conv3 = nn.Conv2d(in_channels=features, **shared)
        self.conv4 = nn.Conv2d(in_channels=features, **shared)

    def forward(self, x):
        """Apply the network to `x` and return the feature tensor.

        Dimensions 3 and 2 of the result are squeezed, so they are dropped
        when they have size 1 and kept otherwise.
        """
        for layer in (self.conv1, self.conv2, self.conv3):
            x = F.relu(layer(x))
        features = self.conv4(x)
        return features.squeeze(3).squeeze(2)

In [None]:
import numpy as np
import torch
from matplotlib import pyplot as plt, colors
from torchvision.transforms import ToTensor 

def plot_acc_byimage(idx, model):
    """Draw the per-pixel accuracy map of image `idx` for `model`.

    The map from disparity_accuracy_byimage is shown in a new 20x10 figure:
    black = ground truth unknown (-1), blue = wrong prediction (0),
    green = correct prediction (1). The caller is expected to call
    `plt.show()`.
    """
    real_disp = get_disp_image(idx)
    predicted_disp = compute_disparity_map(idx, model)
    acc = disparity_accuracy_byimage(real_disp=real_disp, predicted_disp=predicted_disp)

    plt.figure(figsize=(20, 10))
    color_map = colors.ListedColormap(['black', 'blue', 'green'])
    # Pin the value range: without it imshow scales to the data's own
    # min/max, so an image lacking one of the three classes would be drawn
    # with shifted colours.
    plt.imshow(acc, cmap=color_map, vmin=-1, vmax=1)
 
def disparity_accuracy_byimage(real_disp, predicted_disp, pxl_error=PIXEL_ERROR):
    """Classify every pixel of a predicted disparity map.

    Args:
        real_disp: ground-truth disparity map; 0 marks an unknown disparity.
        predicted_disp: predicted disparity map of the same shape.
        pxl_error: a prediction is correct when its absolute difference to
            the ground truth is strictly smaller than this value.

    Returns:
        An array shaped like `real_disp` holding
            -1 where the real disparity is unknown,
             0 where the prediction is incorrect,
             1 where the prediction is correct.
    """
    real = np.asarray(real_disp)
    # Subtract in floating point: the ground truth is an unsigned integer
    # image, and an unsigned difference would wrap around instead of going
    # negative if the prediction were unsigned too.
    error = np.abs(np.asarray(predicted_disp, dtype=np.float64) - real.astype(np.float64))

    known = real != 0
    acc = np.zeros(real.shape)
    acc[~known] = -1
    acc[known & (error < pxl_error)] = 1

    return acc

def compute_accuracy(model, train=True):
    """Return the fraction of correctly predicted pixels over one split.

    Iterates the training range when `train` is true, otherwise the
    validation range. Only pixels with a known (non-zero) ground-truth
    disparity are counted; a prediction is correct when it differs from the
    ground truth by less than PIXEL_ERROR.
    """
    first, last = (TRAIN_START, TRAIN_END) if train else (VALID_START, VALID_END)
    tolerance = PIXEL_ERROR

    hits = 0
    known_pixels = 0
    for idx in range(first, last):
        truth = get_disp_image(idx)
        guess = compute_disparity_map(idx, model)

        has_truth = truth != 0
        within_tolerance = np.abs(guess - truth) < tolerance
        hits += np.count_nonzero(has_truth & within_tolerance)
        known_pixels += np.count_nonzero(has_truth)

    return hits / known_pixels

def compute_disparity_map(idx, model):
    """Predict the disparity map of stereo pair `idx` with `model`.

    Both views are zero-padded by PATCH_SIZE // 2 on each spatial side, run
    through the model, and compared position by position: for every
    candidate disparity d in range(MAX_DISP) the score of left pixel
    (row, col) is the dot product of its feature vector with the feature
    vector of right pixel (row, col - d). The disparity with the highest
    score wins.

    Returns:
        An integer array with the spatial shape of the model output holding
        the selected disparity per pixel.
    """
    max_disp = MAX_DISP
    model.eval()

    left_img = get_image(idx, is_left=True)
    right_img = get_image(idx, is_left=False)

    # Pad only the two spatial axes; a trailing channel axis (if present)
    # is left untouched, so both HxWxC and HxW images are accepted.
    padding = PATCH_SIZE // 2
    spatial_pad = ((padding, padding), (padding, padding))
    left_pad = np.pad(left_img, spatial_pad + ((0, 0),) * (left_img.ndim - 2))
    right_pad = np.pad(right_img, spatial_pad + ((0, 0),) * (right_img.ndim - 2))

    # Run on whatever device the model lives on instead of depending on a
    # global variable defined elsewhere.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        left = ToTensor()(left_pad).unsqueeze(0).to(target_device)
        right = ToTensor()(right_pad).unsqueeze(0).to(target_device)
        left_tensor = model(left)
        right_tensor = model(right)

    # Feature maps as ndarray[HxWxF]
    left_array = left_tensor.cpu().squeeze(0).permute(1, 2, 0).numpy()
    right_array = right_tensor.cpu().squeeze(0).permute(1, 2, 0).numpy()

    # ndarray[HxWxD] of similarity scores. Columns col < d have no partner
    # in the right image; they keep -inf so they can never be selected
    # (a circular shift would compare them with features from the opposite
    # image edge).
    height, width = left_array.shape[0], left_array.shape[1]
    scores = np.full((height, width, max_disp), -np.inf, dtype=left_array.dtype)
    for d in range(min(max_disp, width)):
        scores[:, d:, d] = np.sum(left_array[:, d:] * right_array[:, :width - d], axis=2)

    # ndarray[HxW] - the most similar candidate per pixel
    disp_map = np.argmax(scores, axis=2)

    return disp_map

def plot_images(model, idx):
    """Show three figures for stereo pair `idx`.

    1. Real disparity above the disparity predicted by `model`.
    2. The per-pixel accuracy map.
    3. The left input image above the right input image.
    """
    def show_stacked(panels):
        # Draw the given (loader, title) pairs one below the other.
        for slot, (load, caption) in enumerate(panels, start=1):
            picture = load()
            plt.subplot(2, 1, slot)
            plt.imshow(picture)
            plt.title(caption)
        plt.show()

    show_stacked([
        (lambda: get_disp_image(idx), f'real disparity {idx}'),
        (lambda: compute_disparity_map(idx, model), f'predicted disparity {idx}'),
    ])

    plot_acc_byimage(idx, model)
    plt.title(f"accuracy image {idx}")
    plt.show()

    show_stacked([
        (lambda: get_image(idx, True), f"left image {idx}"),
        (lambda: get_image(idx, False), f"right image {idx}"),
    ])

def evaluate_model(model):
    """Print the accuracy of `model` on the training and validation splits."""
    accuracy_by_split = {
        "train": compute_accuracy(model, train=True),
        "test": compute_accuracy(model, train=False),
    }
    print(f"train accuracy = {accuracy_by_split['train']}\ntest accuracy = {accuracy_by_split['test']}")

# Use the GPU when one is attached, otherwise fall back to the CPU so the
# notebook still runs on a CPU-only runtime.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the checkpoints written after the last training epoch. map_location
# lets checkpoints saved on a GPU be restored on the selected device.
# NOTE(review): these files hold whole pickled model objects, so torch.load
# executes pickle code — only load checkpoints from a trusted source. Recent
# PyTorch releases may additionally require weights_only=False here.
model = torch.load(TRAIN_DATA + f"train_model_{EPOCHS-1}.pth", map_location=device)
modelGRAY = torch.load(TRAIN_DATA + f"train_modelGRAY_{EPOCHS-1}.pth", map_location=device)
model.to(device)

# Visualise one validation image and report the accuracy on both splits.
img_idx = 187
plot_images(model, img_idx)
evaluate_model(model)
