In [None]:
import os
import json
from PIL import Image
import numpy as np
import cv2 as cv
import torch
import torch.nn as nn
import torch.optim as optim
from torch.functional import F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from tqdm import tqdm

In [None]:
# Filesystem layout: the project root is taken to be the parent of the
# current working directory, with all data under `<root>/data`.
ROOT_DIR = os.path.join(os.getcwd(), os.pardir)
DATA_DIR = os.path.join(ROOT_DIR, 'data')

# Road-marking training set: images in `img`, their annotations in `ann`.
TRAIN_DIR = os.path.join(DATA_DIR, 'unet_segmentation_train', 'road_marking')
IMG_DIR, MASK_DIR = (os.path.join(TRAIN_DIR, subdir) for subdir in ('img', 'ann'))

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
class UNetDataset(Dataset):
    """Image / binary-mask pairs for training a segmentation network.

    Every file found in ``img_dir`` is paired with the annotation file
    ``<mask_dir>/<image filename>.json``. The annotation is a dict with an
    ``'objects'`` list whose items carry a polygon under
    ``['points']['exterior']``; all polygons are rasterised into one binary mask.

    Args:
        img_dir: directory containing the images.
        mask_dir: directory containing the JSON annotations.
        height: height (rows) of the returned image and mask.
        width: width (columns) of the returned image and mask.
        transform: optional augmentation. It is called once per sample on a
            float tensor of shape ``(4, height, width)`` (RGB channels followed
            by the mask) so that random geometric transforms hit the image and
            its mask identically. It must therefore accept tensors and should
            be purely geometric (flip, rotation, crop, ...).
        device: device the returned tensors are moved to.
    """

    def __init__(self, img_dir, mask_dir, height, width, transform=None, device='cpu'):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.height = height
        self.width = width
        self.transform = transform
        # Honour the requested device instead of relying on a notebook global.
        self.device = torch.device(device)

        self.img_filenames = []
        self.mask_filenames = []
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping stable between runs and machines.
        for filename in sorted(os.listdir(img_dir)):
            self.img_filenames.append(os.path.join(img_dir, filename))
            self.mask_filenames.append(os.path.join(mask_dir, f'{filename}.json'))

    def __getitem__(self, idx):
        """Returns ``(img, mask)``: a ``(3, H, W)`` float image in [0, 1] and an ``(H, W)`` 0/1 float mask."""
        bgr = self._read_bgr(idx)
        img = self._to_img_tensor(bgr)
        mask = self._build_mask(idx, bgr.shape[:2])

        if self.transform is not None:
            img, mask = self._augment(img, mask)

        return img.to(self.device), mask.to(self.device)

    def __len__(self):
        return len(self.img_filenames)

    def get_img(self, idx):
        """Loads the resized RGB image as a ``(3, H, W)`` float tensor in [0, 1] (no augmentation)."""
        return self._to_img_tensor(self._read_bgr(idx)).to(self.device)

    def get_mask(self, idx):
        """Loads the annotation and returns the ``(H, W)`` binary mask (no augmentation)."""
        # The source image size is needed to map polygon coordinates onto the
        # resized canvas.
        src_size = self._read_bgr(idx).shape[:2]
        return self._build_mask(idx, src_size).to(self.device)

    def process_mask(self, mask):
        """Gets the exterior polygons from an annotation dict."""
        return [obj['points']['exterior'] for obj in mask['objects']]

    def _read_bgr(self, idx):
        """Reads the image at ``idx`` from disk in OpenCV's BGR layout."""
        path = self.img_filenames[idx]
        bgr = cv.imread(path)
        if bgr is None:
            # cv.imread reports a missing/unreadable file by returning None;
            # fail here with the path instead of deep inside cv.resize.
            raise OSError(f'Could not read image: {path}')
        return bgr

    def _to_img_tensor(self, bgr):
        """Resizes a BGR array and converts it to an RGB float tensor in [0, 1]."""
        resized = cv.resize(bgr, (self.width, self.height), interpolation=cv.INTER_CUBIC)
        rgb = cv.cvtColor(resized, cv.COLOR_BGR2RGB)
        # ToTensor already maps uint8 [0, 255] to float [0, 1]; dividing by 255
        # a second time would squash the input into [0, 1/255].
        return transforms.ToTensor()(rgb)

    def _scale_polygon(self, points, src_size):
        """Maps ``[x, y]`` points from the source image onto the resized canvas.

        ``src_size`` is ``(height, width)`` of the image before resizing.
        NOTE(review): assumes annotation coordinates are expressed in pixels of
        the source image -- confirm against the annotation tool's export.
        """
        src_height, src_width = src_size
        scale = np.array([self.width / src_width, self.height / src_height], dtype=np.float64)
        scaled = np.asarray(points, dtype=np.float64).reshape(-1, 2) * scale
        return np.rint(scaled).astype(np.int32)

    def _build_mask(self, idx, src_size):
        """Rasterises every polygon of sample ``idx`` into an ``(H, W)`` 0/1 float tensor on the CPU."""
        with open(self.mask_filenames[idx], 'rb') as file:
            annotation = json.load(file)

        binary_mask = np.zeros((self.height, self.width), dtype=np.uint8)
        for polygone in self.process_mask(annotation):
            points = self._scale_polygon(polygone, src_size)
            if len(points) == 0:
                continue  # nothing to draw for an empty polygon
            cv.fillPoly(binary_mask, [points], color=1)

        return torch.from_numpy(binary_mask).to(torch.float32)

    def _augment(self, img, mask):
        """Applies ``self.transform`` once to image and mask so they stay aligned."""
        n_channels = img.shape[0]
        stacked = torch.cat([img, mask.unsqueeze(0)], dim=0)
        stacked = self.transform(stacked)
        # Re-binarise in case the transform interpolated the mask channel.
        mask_aug = (stacked[n_channels] > 0.5).to(torch.float32)
        return stacked[:n_channels], mask_aug

In [None]:
class UNet(nn.Module):
    """Three-level U-Net with bilinear upsampling and skip connections.

    The encoder applies a double 3x3 convolution followed by 2x2 max-pooling
    three times (64 -> 128 -> 256 channels) and a 512-channel bottleneck. The
    decoder upsamples, concatenates the matching encoder feature map and
    applies another double convolution. A final 1x1 convolution plus sigmoid
    produces ``n_class`` maps with values in [0, 1].

    Args:
        n_class: number of output channels.

    Input is ``(N, 3, H, W)``; output is ``(N, n_class, H, W)``. H and W need
    not be divisible by 8: decoder features are resized to the exact size of
    the skip tensor they are concatenated with.
    """

    def __init__(self, n_class):
        super().__init__()

        # Encoder
        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)

        self.maxpool = nn.MaxPool2d(2)
        # Kept as the single source of truth for the interpolation settings.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder: input channels = upsampled features + skip connection.
        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)

        self.conv_last = nn.Conv2d(64, n_class, 1)

    def _upsample_like(self, x, skip):
        """Upsamples ``x`` to the spatial size of ``skip``.

        Max-pooling floors odd sizes, so a fixed x2 upsample can end up one
        pixel short of the skip tensor and break ``torch.cat``. Targeting the
        skip size is identical to x2 for even sizes (with ``align_corners=True``
        the sampling grid depends only on input and output sizes).
        """
        return nn.functional.interpolate(
            x,
            size=skip.shape[-2:],
            mode=self.upsample.mode,
            align_corners=self.upsample.align_corners,
        )

    def forward(self, x):
        # Encoder path; keep pre-pooling activations for the skip connections.
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)

        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)

        x = self.dconv_down4(x)

        # Decoder path: upsample, merge with the skip tensor, convolve.
        x = torch.cat([self._upsample_like(x, conv3), conv3], dim=1)
        x = self.dconv_up3(x)

        x = torch.cat([self._upsample_like(x, conv2), conv2], dim=1)
        x = self.dconv_up2(x)

        x = torch.cat([self._upsample_like(x, conv1), conv1], dim=1)
        x = self.dconv_up1(x)

        return torch.sigmoid(self.conv_last(x))


def double_conv(in_channels, out_channels):
    """Two 3x3 convolutions (padding 1, so spatial size is preserved), each followed by ReLU."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True)
    )

In [None]:
# Network input resolution and mini-batch size, named so they can be tuned in one place.
IMG_HEIGHT = 720
IMG_WIDTH = 400
BATCH_SIZE = 2

# Random geometric augmentation.
# NOTE(review): for segmentation the mask must receive exactly the same
# flip/rotation as its image -- confirm UNetDataset applies the transform to both.
augmentation = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=45),
])

train_dataset = UNetDataset(
    img_dir=IMG_DIR,
    mask_dir=MASK_DIR,
    height=IMG_HEIGHT,
    width=IMG_WIDTH,
    transform=augmentation,
    device=DEVICE
)

train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    # Reshuffle every epoch; DataLoader defaults to a fixed sequential order,
    # which would present the batches identically in every epoch.
    shuffle=True,
    drop_last=True
)

In [None]:
NUM_EPOCHS = 2
# Seed PyTorch's RNG before the model is built so weight initialisation is repeatable.
torch.manual_seed(42)

model = UNet(n_class=1)
model = model.to(DEVICE)
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=3e-5)

for epoch in range(NUM_EPOCHS):
    train_pbar = tqdm(train_dataloader, desc=f'Epoch {epoch}')
    for img, mask in train_pbar:
        # Drop the single class channel so the prediction lines up with the mask batch.
        output = model(img).squeeze(1)
        loss = criterion(output, mask)

        # Backpropagate, update the weights, then clear gradients for the next batch.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_pbar.set_postfix(loss=loss.item())

In [None]:
# Load one frame for a qualitative check of the trained model.
SAMPLE_PATH = '../data/screenshots/402.jpg'

img = cv.imread(SAMPLE_PATH)
if img is None:
    # cv.imread returns None for a missing or undecodable file instead of
    # raising; stop here with the path rather than failing inside cv.resize.
    raise OSError(f'Could not read image: {SAMPLE_PATH}')

# NOTE(review): training samples are resized to 400x720 (width x height) while
# this frame is resized to 1920x1080 -- confirm the scale mismatch is intended.
img = cv.resize(img, (1920, 1080), interpolation=cv.INTER_CUBIC)
img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
img = Image.fromarray(img)
# Keep a PIL copy of the input for display; `img` itself becomes a tensor below.
img_cp = img.copy()

# ToTensor converts the PIL image to a float (C, H, W) tensor scaled to [0, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
])

img = transform(img)
img = img.to(DEVICE)

In [None]:
# Display the resized RGB input frame (the PIL copy taken before tensor conversion).
img_cp

In [None]:
# Inference only: put the model in eval mode and disable autograd so the
# full-resolution forward pass does not retain activations for a backward pass.
model.eval()
with torch.no_grad():
    predict = model(img.unsqueeze(0))  # add the batch axis: (C, H, W) -> (1, C, H, W)
# Drop the batch and single class axes: (1, 1, H, W) -> (H, W).
predict = predict.squeeze(0, 1)

In [None]:
# `transforms` is already imported in the first cell; re-importing it here under
# the same name would silently rebind it to a different module object.
to_pil_transform = transforms.ToPILImage()
# Detach from the autograd graph and move to host memory before conversion.
# NOTE: `img` is rebound from the input tensor to the predicted PIL image, so
# the prediction cell cannot be re-run without re-running the loading cell first.
img = to_pil_transform(predict.detach().cpu())

In [None]:
# Display the predicted mask, converted to a PIL image in the previous cell.
img