In [None]:
import os

from datetime import datetime
import numpy as np
import pandas as pd
import random

from PIL import Image
import cv2
import matplotlib.pyplot as plt

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import albumentations as album

In [None]:
from bs4 import BeautifulSoup
from urllib.request import urlopen

def downloadImages(url, path, max_images=1):
    """Download images linked from an HTML index page and store them as PNG files.

    Every ``<a href=...>`` element of the page at ``url`` is treated as a link
    to an image. Image number ``k`` (in document order) is written to
    ``<path>/img-k.png``; files that already exist are not downloaded again.

    Args:
        url: Address of the HTML index page listing the images.
        path: Destination directory (with or without a trailing separator).
            It is created when missing.
        max_images: Maximum number of links to process, counted from the top
            of the page. The default of 1 keeps the historical behaviour of
            fetching only the first link; pass ``None`` to fetch every link.

    Raises:
        urllib.error.URLError: If the index page or an image cannot be fetched.
    """
    # Local import so the function does not rely on an extra notebook-level import.
    from urllib.parse import urljoin

    # Close the HTTP response as soon as the page has been read.
    with urlopen(url) as website:
        html = website.read()

    bs_parsed = BeautifulSoup(html, "html5lib")

    # An empty path means "current directory"; nothing needs to be created then.
    if path:
        os.makedirs(path, exist_ok=True)

    for image_id, link in enumerate(bs_parsed.find_all("a", href=True)):
        if max_images is not None and image_id >= max_images:
            break

        # os.path.join inserts the separator that plain concatenation omitted,
        # so the file ends up inside ``path`` instead of next to it.
        target_file = os.path.join(path, "img-%d.png" % image_id)
        if os.path.isfile(target_file):
            continue

        # Resolve relative hrefs against the index page; absolute hrefs pass
        # through urljoin unchanged.
        image_url = urljoin(url, link["href"])
        with urlopen(image_url) as response:
            # save() forces the pixel data to be read while the response is open.
            Image.open(response).save(target_file, "PNG")
    

In [None]:
# Index pages that list the satellite images ("sat") and road maps ("map")
# for the train / test / val splits.
TRAIN_IMAGE_URL="https://www.cs.toronto.edu/~vmnih/data/mass_roads/train/sat/index.html"
TRAIN_TARGET_URL="https://www.cs.toronto.edu/~vmnih/data/mass_roads/train/map/index.html"
TEST_IMAGE_URL="https://www.cs.toronto.edu/~vmnih/data/mass_roads/test/sat/index.html"
TEST_TARGET_URL="https://www.cs.toronto.edu/~vmnih/data/mass_roads/test/map/index.html"
VAL_IMAGE_URL="https://www.cs.toronto.edu/~vmnih/data/mass_roads/val/sat/index.html"
VAL_TARGET_URL="https://www.cs.toronto.edu/~vmnih/data/mass_roads/val/map/index.html"

# (index page, destination directory) pairs. The validation maps go to
# 'images/val/target' so the directory that is created is the one that is
# actually written to.
_DOWNLOAD_JOBS = [
    (TRAIN_IMAGE_URL, 'images/train/input'),
    (TRAIN_TARGET_URL, 'images/train/target'),
    (TEST_IMAGE_URL, 'images/test/input'),
    (TEST_TARGET_URL, 'images/test/target'),
    (VAL_IMAGE_URL, 'images/val/input'),
    (VAL_TARGET_URL, 'images/val/target'),
]

for _index_url, _directory in _DOWNLOAD_JOBS:
    # exist_ok=True lets the cell be re-run without raising FileExistsError.
    os.makedirs(_directory, exist_ok=True)
    # Joining with '' appends a trailing separator, so the files are written
    # inside the directory rather than next to it with the directory name
    # as a file-name prefix.
    downloadImages(url=_index_url, path=os.path.join(_directory, ''))

In [None]:
# Root folder that holds one sub-folder per split plus a matching
# "<split>_labels" sub-folder.
DATA_DIR = '/tiff'

# x_* point at the image folders, y_* at the corresponding label folders.
x_train_dir, y_train_dir = (
    os.path.join(DATA_DIR, folder) for folder in ('train', 'train_labels')
)
x_valid_dir, y_valid_dir = (
    os.path.join(DATA_DIR, folder) for folder in ('val', 'val_labels')
)
x_test_dir, y_test_dir = (
    os.path.join(DATA_DIR, folder) for folder in ('test', 'test_labels')
)

In [None]:
# Class name -> RGB colour used for that class in the label images.
# Insertion order defines the class index (0 = background, 1 = road).
_CLASS_COLORS = {
    'background': [0, 0, 0],
    'road': [255, 255, 255],
}
class_names = list(_CLASS_COLORS.keys())
class_rgb = list(_CLASS_COLORS.values())

In [None]:
def encoder(label, label_values):
    """One-hot encode a colour-coded label image.

    Args:
        label: Array whose last axis holds the colour components of each pixel.
        label_values: Sequence of colours, one per class; each colour is
            compared against the last axis of ``label``.

    Returns:
        Boolean array with the leading axes of ``label`` and a trailing axis
        of length ``len(label_values)``. Entry ``[..., k]`` is True where the
        pixel equals ``label_values[k]`` in every component.
    """
    per_class_maps = [
        # A pixel belongs to the class only if all of its components match.
        np.all(np.equal(label, class_color), axis=-1)
        for class_color in label_values
    ]
    return np.stack(per_class_maps, axis=-1)

def decoder(image):
    """Collapse the last axis of ``image`` to the index of its largest entry.

    Args:
        image: Array-like whose last axis enumerates per-class values.

    Returns:
        Integer array with the last axis removed, holding the arg-max index.
    """
    class_indices = np.argmax(image, axis=-1)
    return class_indices


def color_code_seg(image, label_values):
    """Map an array of class indices to the corresponding colours.

    Args:
        image: Array of class indices; it is cast to ``int`` before lookup.
        label_values: Sequence of colours, indexed by class.

    Returns:
        Array in which every index of ``image`` is replaced by the matching
        row of ``label_values``.
    """
    palette = np.array(label_values)
    indices = image.astype(int)
    return palette[indices]
    

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Image / mask pairs read from two directories.

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the label images.
        class_rgb_values: Colours handed to ``encoder`` to one-hot encode a mask.
        augmentation: Optional callable invoked as
            ``augmentation(image=..., mask=...)`` that returns a mapping with
            the keys ``'image'`` and ``'mask'``.
        preprocessing: Optional callable with the same calling convention,
            applied after ``augmentation``.
    """

    def __init__(
            self,
            images_dir,
            masks_dir,
            class_rgb_values=None,
            augmentation=None,
            preprocessing=None,
    ):
        # Both listings are sorted so that position i in one list pairs with
        # position i in the other (presumably the file names match between the
        # two directories; verify against the data).
        image_names = sorted(os.listdir(images_dir))
        mask_names = sorted(os.listdir(masks_dir))
        self.image_paths = [os.path.join(images_dir, name) for name in image_names]
        self.mask_paths = [os.path.join(masks_dir, name) for name in mask_names]

        self.class_rgb_values = class_rgb_values
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Return the ``(image, mask)`` pair stored at position ``i``."""
        # cv2.imread yields BGR channel order; convert both arrays to RGB.
        image_bgr = cv2.imread(self.image_paths[i])
        image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        mask_bgr = cv2.imread(self.mask_paths[i])
        mask_rgb = cv2.cvtColor(mask_bgr, cv2.COLOR_BGR2RGB)

        # One channel per class, stored as floats.
        mask = encoder(mask_rgb, self.class_rgb_values).astype('float')

        # Augmentation runs first, preprocessing second; either may be unset.
        for transform in (self.augmentation, self.preprocessing):
            if transform:
                sample = transform(image=image, mask=mask)
                image, mask = sample['image'], sample['mask']
        return image, mask

    def __len__(self):
        """Number of samples, taken from the image listing."""
        return len(self.image_paths)

In [None]:
# Plain training dataset: no augmentation and no preprocessing.
dataset = Dataset(
    images_dir=x_train_dir,
    masks_dir=y_train_dir,
    class_rgb_values=class_rgb,
)

In [None]:
## Training and preprocessing steps from https://www.kaggle.com/balraj98/unet-resnet50-frontend-road-segmentation-pytorch

def training_augmentation():
    """Build the albumentations pipeline applied to training samples.

    Returns:
        An ``album.Compose`` that always takes a random 256x256 crop and then,
        with probability 0.75, applies one of: horizontal flip, vertical flip
        or a random 90-degree rotation.
    """
    random_crop = album.RandomCrop(height=256, width=256, always_apply=True)
    # p=1 on the members: once the OneOf group is selected, the chosen
    # transform is always applied.
    flip_or_rotate = album.OneOf(
        [
            album.HorizontalFlip(p=1),
            album.VerticalFlip(p=1),
            album.RandomRotate90(p=1),
        ],
        p=0.75,
    )
    return album.Compose([random_crop, flip_or_rotate])


def validation_augmentation():
    """Build the albumentations pipeline applied to validation samples.

    Returns:
        An ``album.Compose`` that pads every sample to at least 1536x1536.
    """
    # border_mode=0 is passed straight to albumentations (0 is
    # cv2.BORDER_CONSTANT). NOTE(review): 1536 is presumably chosen to suit
    # the network's down-sampling factor; confirm against the image size.
    pad_to_min_size = album.PadIfNeeded(
        min_height=1536,
        min_width=1536,
        always_apply=True,
        border_mode=0,
    )
    return album.Compose([pad_to_min_size])


def to_tensor(x, **kwargs):
    """Move axis 2 of ``x`` to the front and cast the result to float32.

    Args:
        x: Three-dimensional array (presumably height x width x channels;
            confirm with the caller).
        **kwargs: Accepted and ignored, so the function can be used where
            extra keyword arguments are supplied.

    Returns:
        ``x`` with axes reordered to ``(2, 0, 1)`` as a float32 array.
    """
    reordered = x.transpose(2, 0, 1)
    return reordered.astype('float32')


def preprocessing(preprocessing_fn=None):
    """Build the preprocessing pipeline.

    Args:
        preprocessing_fn: Optional callable applied to the image only.

    Returns:
        An ``album.Compose`` that first applies ``preprocessing_fn`` to the
        image (when given) and then runs ``to_tensor`` on both image and mask.
    """
    steps = [album.Lambda(image=preprocessing_fn)] if preprocessing_fn else []
    # The tensor conversion always comes last and covers image and mask.
    steps.append(album.Lambda(image=to_tensor, mask=to_tensor))
    return album.Compose(steps)

In [None]:
# Training dataset with the training augmentations but without preprocessing.
augmented_dataset = Dataset(
    images_dir=x_train_dir,
    masks_dir=y_train_dir,
    augmentation=training_augmentation(),
    class_rgb_values=class_rgb,
)

In [None]:
import segmentation_models_pytorch as smp

# Encoder backbone and the pretrained weights loaded into it.
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
# One output channel per entry of class_names.
CLASSES = class_names
# NOTE(review): sigmoid is applied to the outputs although the masks are
# one-hot encoded; confirm this is preferred over a softmax over classes.
ACTIVATION = 'sigmoid'

_unet_config = dict(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)
model = smp.Unet(**_unet_config)

# Input normalisation function matching the chosen encoder and weights.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

In [None]:
# Training samples: random crop / flips, then encoder-specific preprocessing.
train_dataset = Dataset(
    images_dir=x_train_dir,
    masks_dir=y_train_dir,
    augmentation=training_augmentation(),
    preprocessing=preprocessing(preprocessing_fn),
    class_rgb_values=class_rgb,
)

# Validation samples: padding only, then the same preprocessing.
valid_dataset = Dataset(
    images_dir=x_valid_dir,
    masks_dir=y_valid_dir,
    augmentation=validation_augmentation(),
    preprocessing=preprocessing(preprocessing_fn),
    class_rgb_values=class_rgb,
)

# Training batches are shuffled; validation runs one sample at a time in order.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
)
valid_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)

  cpuset_checked))


In [None]:
# Number of passes over the training data.
EPOCHS = 5

# Use the GPU when one is available, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_device_name)

# Dice loss for optimisation, IoU (thresholded at 0.5) for reporting.
loss = smp.utils.losses.DiceLoss()
metrics = [smp.utils.metrics.IoU(threshold=0.5)]

# A single parameter group covering every model parameter.
_param_groups = [dict(params=model.parameters(), lr=0.00008)]
optimizer = torch.optim.Adam(_param_groups)

# NOTE(review): the scheduler only changes the learning rate when
# lr_scheduler.step() is called by the training code.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=1,
    T_mult=2,
    eta_min=5e-5,
)

In [None]:
# Arguments shared by the training and the validation runner.
_epoch_kwargs = dict(
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

# Only the training runner receives the optimizer.
train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **_epoch_kwargs)

valid_epoch = smp.utils.train.ValidEpoch(model, **_epoch_kwargs)

In [None]:
# Per-epoch logs are kept so they can be inspected after training;
# train_logs / valid_logs still hold the logs of the most recent epoch.
train_logs_history, valid_logs_history = [], []

for i in range(0, EPOCHS):
    print('\nEpoch: '+ str(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)
    train_logs_history.append(train_logs)
    valid_logs_history.append(valid_logs)

    # Advance the learning-rate schedule once per epoch. Without this call
    # the scheduler is never applied and the optimizer keeps its initial
    # learning rate for the whole run.
    lr_scheduler.step()


Epoch: 0
train:   0%|          | 0/70 [00:00<?, ?it/s]

  cpuset_checked))
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


train: 100%|██████████| 70/70 [15:09<00:00, 13.00s/it, dice_loss - 0.3815, iou_score - 0.6017]
valid: 100%|██████████| 14/14 [00:16<00:00,  1.16s/it, dice_loss - 0.3532, iou_score - 0.6096]

Epoch: 1
train: 100%|██████████| 70/70 [17:26<00:00, 14.95s/it, dice_loss - 0.2808, iou_score - 0.8038]
valid: 100%|██████████| 14/14 [00:24<00:00,  1.77s/it, dice_loss - 0.2749, iou_score - 0.7981]

Epoch: 2
train: 100%|██████████| 70/70 [18:47<00:00, 16.11s/it, dice_loss - 0.2305, iou_score - 0.8707]
valid: 100%|██████████| 14/14 [00:23<00:00,  1.69s/it, dice_loss - 0.2375, iou_score - 0.826]

Epoch: 3
train: 100%|██████████| 70/70 [13:17<00:00, 11.40s/it, dice_loss - 0.1948, iou_score - 0.8929]
valid: 100%|██████████| 14/14 [00:15<00:00,  1.12s/it, dice_loss - 0.2104, iou_score - 0.8383]

Epoch: 4
train: 100%|██████████| 70/70 [11:39<00:00,  9.99s/it, dice_loss - 0.166, iou_score - 0.9049]
valid: 100%|██████████| 14/14 [00:18<00:00,  1.29s/it, dice_loss - 0.1805, iou_score - 0.854]
CPU times: us

In [None]:
# Serialises the whole model object (not just its state_dict) to model.pth.
# NOTE(review): loading such a file unpickles the object, so the model's
# class definitions presumably must be importable at load time; confirm.
torch.save(obj=model, f="model.pth")