<a href="https://colab.research.google.com/github/dchung117/self-driving-cars-cv/blob/main/fcn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

 # Package installations

In [1]:
!pip install -q kaggle

# Imports

In [51]:
import pandas as pd
import numpy as np
import os
import random
import datetime
from base64 import b64encode
from pathlib import Path

import torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms.functional as TF
from torch.utils.data import random_split, Dataset

import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.image as mpimg
from IPython.display import clear_output, HTML

# Colab helper for uploading local files (used below to upload kaggle.json)
from google.colab import files

# Download KITTI dataset

In [4]:
# Upload kaggle json
files.upload()

!mkdir ~/.kaggle
!mv kaggle.json ~/.kaggle
!chmod 600 ~/.kaggle/kaggle.json

Saving kaggle.json to kaggle.json


In [5]:
# Download KITTI road segmentation data
!kaggle datasets download -d sakshaymahna/kittiroadsegmentation
!unzip -q kittiroadsegmentation.zip

Downloading kittiroadsegmentation.zip to /content
 97% 297M/305M [00:02<00:00, 72.0MB/s]
100% 305M/305M [00:02<00:00, 109MB/s] 


In [6]:
# Split the labelled images in training/image_2 into test / validation / train subsets.
# Note: test_data_dir is defined here but the test subset below is carved out of
# the training images, not read from it.
RANDOM_STATE = 27

train_data_dir = Path("training")
test_data_dir = Path("testing")

val_split = 0.1
test_split = 0.1

# Sort before shuffling: iterdir() yields entries in arbitrary (filesystem-dependent)
# order, so without a canonical starting order the seeded shuffle would not
# reproduce the same split on another machine.
train_data_dir_list = sorted((train_data_dir / "image_2").iterdir())
n = len(train_data_dir_list)
random.Random(RANDOM_STATE).shuffle(train_data_dir_list)

n_test = int(test_split * n)
n_val = int(val_split * n)

test_data_files = train_data_dir_list[:n_test]
val_data_files = train_data_dir_list[n_test:n_test + n_val]
train_data_files = train_data_dir_list[n_test + n_val:]

assert len(test_data_files) + len(val_data_files) + len(train_data_files) == n

# Constants

In [52]:
# Data-loading / model hyperparameters
IMG_SIZE: int = 128   # images and masks are resized to IMG_SIZE x IMG_SIZE in the collate functions
N_CHANNELS: int = 3   # presumably RGB input channels (not used in the cells shown here)
N_CLASSES: int = 1    # presumably one output channel for the binary road mask (not used in the cells shown here)
BATCH_SIZE: int = 32  # samples per DataLoader batch

# Datasets

In [12]:
class KITTIDataset(Dataset):
  """KITTI road-segmentation dataset yielding (image, binary road mask) pairs.

  Files are read from disk on every access:
    * image: tensor returned by ``torchvision.io.read_image`` (C, H, W).
    * mask:  uint8 tensor of shape (1, H, W); 1 where the ground-truth pixel
             equals ROAD_COLOR in every channel, 0 elsewhere.

  Args:
    data_files: paths (str or Path) to images inside an ``image_2`` directory.
      The ground truth is looked up in the sibling ``gt_image_2`` directory,
      with ``_road`` inserted after the category prefix for the um/umm/uu
      categories (e.g. ``um_000000.png`` -> ``um_road_000000.png``).
  """

  # Ground-truth colour that marks road pixels, shaped (3, 1, 1) so it
  # broadcasts over a (3, H, W) mask image.
  ROAD_COLOR = torch.tensor([255, 0, 255], dtype=torch.uint8).view(3, 1, 1)

  # Image-name prefixes whose ground truth is stored as "<prefix>_road_<id>".
  _ROAD_CATEGORIES = ("um", "umm", "uu")

  def __init__(self, data_files: list):
    self.data_files = data_files

  def __len__(self):
    return len(self.data_files)

  @classmethod
  def _mask_path(cls, image_file) -> Path:
    """Map ``<root>/image_2/<cat>_<id>.png`` to ``<root>/gt_image_2/<cat>_road_<id>.png``.

    Only the file name and its immediate parent directory are rewritten, so
    directory names higher up the path are never altered.
    """
    image_path = Path(image_file)
    category, sep, rest = image_path.name.partition("_")
    if sep and category in cls._ROAD_CATEGORIES:
      mask_name = f"{category}_road_{rest}"
    else:
      mask_name = image_path.name
    return image_path.parent.parent / "gt_image_2" / mask_name

  def __getitem__(self, idx: int):
    file = str(self.data_files[idx])
    img = torchvision.io.read_image(file)
    mask = torchvision.io.read_image(str(self._mask_path(file)))

    # A pixel is road only if all of its channels match ROAD_COLOR.
    mask = torch.all(mask == self.ROAD_COLOR, dim=0).unsqueeze(dim=0).to(torch.uint8)

    return img, mask

In [9]:
# Wrap each file split in its own KITTIDataset (train, validation, test)
train_dataset, val_dataset, test_dataset = (
    KITTIDataset(split_files)
    for split_files in (train_data_files, val_data_files, test_data_files)
)

# Image Augmentations

In [58]:
def normalize(image: torch.Tensor, mask: torch.Tensor) -> tuple:
  """Cast the image to float32 and divide by 255; the mask is returned unchanged.

  For a uint8 image this maps pixel values from [0, 255] to [0, 1].

  Returns:
    (scaled_image, mask)
  """
  scaled = image.float().div(255.0)
  return scaled, mask

# train collate function
def train_collator(batch: list) -> tuple:
  """Collate (image, mask) pairs into an augmented training batch.

  Steps for each pair:
    1. Resize both tensors to IMG_SIZE x IMG_SIZE.
    2. With probability 0.5, flip image and mask horizontally together.
    3. Pass the pair through ``normalize``.

  Args:
    batch: list of (image, mask) tensor pairs, e.g. from KITTIDataset.

  Returns:
    (images, masks), each stacked along a new leading batch dimension.
  """
  def collate_fn(image: torch.Tensor, mask: torch.Tensor) -> tuple:
    image = TF.resize(image, (IMG_SIZE, IMG_SIZE))
    # The mask is a label map. Nearest-neighbour makes every output pixel copy
    # an existing label instead of an interpolated value.
    mask = TF.resize(mask, (IMG_SIZE, IMG_SIZE),
                     interpolation=TF.InterpolationMode.NEAREST)

    # Flip image and mask together so they stay pixel-aligned.
    if np.random.uniform() < 0.5:
      image = TF.hflip(image)
      mask = TF.hflip(mask)

    return normalize(image, mask)

  images, masks = zip(*(collate_fn(*pair) for pair in batch))
  return torch.stack(images, dim=0), torch.stack(masks, dim=0)

def test_collator(batch: list) -> tuple:
  """Collate (image, mask) pairs into a deterministic evaluation batch.

  Each pair is resized to IMG_SIZE x IMG_SIZE and passed through
  ``normalize``. No random augmentation is applied, so validation and test
  batches are reproducible.

  Args:
    batch: list of (image, mask) tensor pairs, e.g. from KITTIDataset.

  Returns:
    (images, masks), each stacked along a new leading batch dimension.
  """
  def collate_fn(image: torch.Tensor, mask: torch.Tensor) -> tuple:
    image = TF.resize(image, (IMG_SIZE, IMG_SIZE))
    # Resize the mask itself (not the image). Nearest-neighbour makes every
    # output pixel copy an existing label instead of an interpolated value.
    mask = TF.resize(mask, (IMG_SIZE, IMG_SIZE),
                     interpolation=TF.InterpolationMode.NEAREST)
    return normalize(image, mask)

  images, masks = zip(*(collate_fn(*pair) for pair in batch))
  return torch.stack(images, dim=0), torch.stack(masks, dim=0)

# Dataloaders

In [61]:
# Settings shared by all three loaders
_loader_kwargs = dict(batch_size=BATCH_SIZE, pin_memory=True)

# Only the training loader shuffles and uses the augmenting collator;
# validation and test use the non-random test_collator.
train_dl = DataLoader(train_dataset, shuffle=True, collate_fn=train_collator, **_loader_kwargs)
val_dl = DataLoader(val_dataset, shuffle=False, collate_fn=test_collator, **_loader_kwargs)
test_dl = DataLoader(test_dataset, shuffle=False, collate_fn=test_collator, **_loader_kwargs)